Unverified commit 1a1a2ce8, authored by Liu-xiandong, committed by GitHub

[KP] Add elementwise add xpu after phi, test=develop (#39787)

* [KP] Add elementwise add xpu, test=develop

* modify the File Permissions

* modify the copyright time

* modify code style

* modify code style
Parent b7bcd0f6
@@ -25,7 +25,7 @@ repos:
 description: Format files with ClangFormat.
 entry: bash ./tools/codestyle/clang_format.hook -i
 language: system
-files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto)$
+files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|xpu|kps)$
 - repo: local
 hooks:
 - id: cpplint-cpp-source
@@ -48,7 +48,7 @@ repos:
 name: copyright_checker
 entry: python ./tools/codestyle/copyright.hook
 language: system
-files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py|sh)$
+files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|xpu|kps|py|sh)$
 exclude: |
 (?x)^(
 paddle/utils/.*
...
@@ -125,6 +125,9 @@ function(op_library TARGET)
 if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.xpu)
 list(APPEND xpu_kp_cc_srcs ${TARGET}.xpu)
 endif()
+if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.kps)
+list(APPEND xpu_kp_cc_srcs ${TARGET}.kps)
+endif()
 endif()
 if(WITH_ASCEND_CL)
 string(REPLACE "_op" "_op_npu" NPU_FILE "${TARGET}")
@@ -162,6 +165,8 @@ function(op_library TARGET)
 list(APPEND xpu_cc_srcs ${src})
 elseif(WITH_XPU_KP AND ${src} MATCHES ".*\\.xpu$")
 list(APPEND xpu_kp_cc_srcs ${src})
+elseif(WITH_XPU_KP AND ${src} MATCHES ".*\\.kps$")
+list(APPEND xpu_kp_cc_srcs ${src})
 elseif(WITH_ASCEND_CL AND ${src} MATCHES ".*_op_npu.cc$")
 list(APPEND npu_cc_srcs ${src})
 elseif(WITH_MLU AND ${src} MATCHES ".*_op_mlu.cc$")
@@ -384,7 +389,15 @@ function(op_library TARGET)
 # pybind USE_OP_DEVICE_KERNEL for XPU KP
 if (WITH_XPU_KP AND ${xpu_kp_cc_srcs_len} GREATER 0)
-file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, KP);\n")
+foreach(xpu_kp_src ${xpu_kp_cc_srcs})
+set(op_name "")
+find_register(${xpu_kp_src} "REGISTER_OP_KERNEL" op_name)
+if(NOT ${op_name} EQUAL "")
+file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${op_name}, KP);\n")
+message(STATUS "Building KP Target: ${op_name}")
+set(pybind_flag 1)
+endif()
+endforeach()
 endif()
 # pybind USE_OP_DEVICE_KERNEL for NPU
...
@@ -17,7 +17,7 @@ if(NOT WITH_XPU_KP)
 endif()
 if(NOT XPU_TOOLCHAIN)
-set(XPU_TOOLCHAIN /workspace/paddle/xpu-demo/XTDK)
+set(XPU_TOOLCHAIN /workspace/output/XTDK-ubuntu_x86_64)
 get_filename_component(XPU_TOOLCHAIN ${XPU_TOOLCHAIN} REALPATH)
 endif()
 if(NOT IS_DIRECTORY ${XPU_TOOLCHAIN})
@@ -102,7 +102,7 @@ macro(compile_kernel COMPILE_ARGS)
 set(XTDK_DIR ${XPU_TOOLCHAIN})
 set(CXX_DIR ${HOST_SYSROOT})
-set(XPU_CXX_FLAGS -Wno-error=pessimizing-move -Wno-error=constant-conversion -Wno-error=c++11-narrowing -Wno-error=shift-count-overflow -Wno-error=unused-local-typedef -Wno-error=deprecated-declarations -Wno-deprecated-declarations -std=c++14 -m64 -fPIC -fno-omit-frame-pointer -Wall -Wno-inconsistent-missing-override -Wextra -Wnon-virtual-dtor -Wdelete-non-virtual-dtor -Wno-unused-parameter -Wno-unused-function -Wno-error=unused-local-typedefs -Wno-error=ignored-attributes -Wno-error=int-in-bool-context -Wno-error=parentheses -Wno-error=address -Wno-ignored-qualifiers -Wno-ignored-attributes -Wno-parentheses -DNDEBUG )
+set(XPU_CXX_FLAGS -fforce-enable-int128 -Wno-error=pessimizing-move -Wno-error=constant-conversion -Wno-error=c++11-narrowing -Wno-error=shift-count-overflow -Wno-error=unused-local-typedef -Wno-error=deprecated-declarations -Wno-deprecated-declarations -std=c++14 -m64 -fPIC -fno-omit-frame-pointer -Wall -Wno-inconsistent-missing-override -Wextra -Wnon-virtual-dtor -Wdelete-non-virtual-dtor -Wno-unused-parameter -Wno-unused-function -Wno-error=unused-local-typedefs -Wno-error=ignored-attributes -Wno-error=int-in-bool-context -Wno-error=parentheses -Wno-error=address -Wno-ignored-qualifiers -Wno-ignored-attributes -Wno-parentheses -DNDEBUG )
 #include path
 get_property(dirs DIRECTORY ${CMAKE_SOURCE_DIR} PROPERTY INCLUDE_DIRECTORIES)
@@ -127,9 +127,11 @@ macro(compile_kernel COMPILE_ARGS)
 kernel_build/${kernel_name}.bin.o
 COMMAND
 ${CMAKE_COMMAND} -E make_directory kernel_build
+COMMAND
+cp ${kernel_path}/${kernel_name}.kps kernel_build/${kernel_name}.xpu
 COMMAND
 ${XPU_CLANG} --sysroot=${CXX_DIR} -std=c++11 -D_GLIBCXX_USE_CXX11_ABI=1 ${OPT_LEVEL} -fno-builtin -mcpu=xpu2 -fPIC ${XPU_CXX_DEFINES} ${XPU_CXX_FLAGS} ${XPU_CXX_INCLUDES}
--I. -o kernel_build/${kernel_name}.bin.o.sec ${kernel_path}/${kernel_name}.xpu
+-I. -o kernel_build/${kernel_name}.bin.o.sec kernel_build/${kernel_name}.xpu
 --xpu-device-only -c -v
 COMMAND
 ${XTDK_DIR}/bin/xpu2-elfconv kernel_build/${kernel_name}.bin.o.sec kernel_build/${kernel_name}.bin.o ${XPU_CLANG} --sysroot=${CXX_DIR}
@@ -148,9 +150,11 @@ macro(compile_kernel COMPILE_ARGS)
 kernel_build/${kernel_name}.host.o
 COMMAND
 ${CMAKE_COMMAND} -E make_directory kernel_build
+COMMAND
+cp ${kernel_path}/${kernel_name}.kps kernel_build/${kernel_name}.xpu
 COMMAND
 ${XPU_CLANG} --sysroot=${CXX_DIR} -std=c++11 -D_GLIBCXX_USE_CXX11_ABI=1 ${OPT_LEVEL} -fno-builtin -mcpu=xpu2 -fPIC ${XPU_CXX_DEFINES} ${XPU_CXX_FLAGS} ${XPU_CXX_INCLUDES}
--I. -o kernel_build/${kernel_name}.host.o ${kernel_path}/${kernel_name}.xpu
+-I. -o kernel_build/${kernel_name}.host.o kernel_build/${kernel_name}.xpu
 --xpu-host-only -c -v
 WORKING_DIRECTORY
 ${CMAKE_CURRENT_BINARY_DIR}
@@ -185,7 +189,7 @@ macro(xpu_add_library TARGET_NAME)
 # Distinguish .xpu file from other files
 foreach(cur_xpu_src IN LISTS xpu_srcs_lists)
 get_filename_component(language_type_name ${cur_xpu_src} EXT)
-if(${language_type_name} STREQUAL ".xpu")
+if(${language_type_name} STREQUAL ".kps")
 list(APPEND xpu_kernel_lists ${cur_xpu_src})
 else()
 list(APPEND cc_kernel_lists ${cur_xpu_src})
...
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// Please do not modify the following code
#if defined(__CUDA_ARCH__)
#undef __CUDA_ARCH__
#endif
#if defined(__CUDACC__)
#undef __CUDACC__
#endif
#if defined(__CUDA__)
#undef __CUDA__
#endif
#if defined(__NVCC__)
#undef __NVCC__
#endif
#ifdef PADDLE_WITH_XPU_KP
#include <xpu/runtime.h> // NOLINT
#include "xpu/kernel/cluster_header.h" // NOLINT
#include "xpu/kernel/debug.h" // NOLINT
#include "xpu/kernel/math.h" // NOLINT
#include <memory>
#include <string>
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_xpu.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
namespace paddle {
namespace operators {
template <typename T>
class ElementwiseAddXPUKPKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
std::vector<const framework::Tensor*> ins;
std::vector<framework::Tensor*> outs;
int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
const auto& xpu_ctx =
ctx.template device_context<paddle::platform::XPUDeviceContext>();
paddle::operators::LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T,
T, kps::AddFunctor<T>, 1>(
xpu_ctx, ins, &outs, axis, kps::AddFunctor<T>());
}
};
static std::vector<int> get_rdims(const std::vector<int>& xdims,
const std::vector<int>& ydims) {
std::vector<int> rdims;
for (size_t i = 0; i < xdims.size(); i++) {
if (xdims[i] != ydims[i]) {
rdims.push_back(i);
}
}
return rdims;
}
template <typename T>
class ElementwiseAddGradXPUKPKernel : public ElemwiseGradKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
auto* x = ctx.Input<framework::Tensor>("X");
auto* y = ctx.Input<framework::Tensor>("Y");
auto* dz = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
const framework::DDim& x_dims = x->dims();
const framework::DDim& y_dims = y->dims();
const framework::DDim& dz_dims = dz->dims();
int axis = ctx.Attr<int>("axis");
axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
int max_dim = std::max(x_dims.size(), y_dims.size());
PADDLE_ENFORCE_GE(
axis, 0,
platform::errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LT(
axis, max_dim,
platform::errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.", max_dim,
axis));
std::vector<int> x_dims_vec(max_dim, 1);
std::vector<int> y_dims_vec(max_dim, 1);
std::vector<int> z_dims_vec(max_dim, 1);
if (x_dims.size() == max_dim) {
for (int i = 0; i < max_dim; i++) {
x_dims_vec[i] = x_dims[i];
}
} else {
for (int i = 0; i < x_dims.size(); i++) {
x_dims_vec[i + axis] = x_dims[i];
}
}
if (y_dims.size() == max_dim) {
for (int i = 0; i < max_dim; i++) {
y_dims_vec[i] = y_dims[i];
}
} else {
for (int i = 0; i < y_dims.size(); i++) {
y_dims_vec[i + axis] = y_dims[i];
}
}
for (int i = 0; i < max_dim; i++) {
z_dims_vec[i] = dz_dims[i];
}
std::vector<int> rdims_for_x;
std::vector<int> rdims_for_y;
rdims_for_x = get_rdims(x_dims_vec, z_dims_vec);
rdims_for_y = get_rdims(y_dims_vec, z_dims_vec);
const T* dz_data = dz->data<T>();
auto& dev_ctx =
ctx.template device_context<paddle::platform::XPUDeviceContext>();
if (dx != nullptr) {
T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
if (rdims_for_x.size() == 0) {
if (dx_data != dz_data) {
framework::TensorCopy(
*dz, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), dx);
}
} else {
// For inplace strategy, dx will be stored in addr of dz, which makes
// the result of dy wrong.
if (dx->IsSharedBufferWith(*dz)) {
dx->clear();
dx->mutable_data<T>(x->dims(), ctx.GetPlace());
}
int ret = xpu::reduce_sum<XPUType>(
dev_ctx.x_context(), reinterpret_cast<const XPUType*>(dz_data),
reinterpret_cast<XPUType*>(dx_data), z_dims_vec, rdims_for_x);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "reduce_sum ");
}
}
if (dy != nullptr) {
T* dy_data = dy->mutable_data<T>(ctx.GetPlace());
if (rdims_for_y.size() == 0) {
if (dy_data != dz_data) {
framework::TensorCopy(
*dz, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), dy);
}
} else {
int ret = xpu::reduce_sum<XPUType>(
dev_ctx.x_context(), reinterpret_cast<const XPUType*>(dz_data),
reinterpret_cast<XPUType*>(dy_data), z_dims_vec, rdims_for_y);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "reduce_sum ");
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_KERNEL(elementwise_add, KP, plat::XPUPlace,
ops::ElementwiseAddXPUKPKernel<float>);
REGISTER_OP_KERNEL(elementwise_add_grad, KP, plat::XPUPlace,
ops::ElementwiseAddGradXPUKPKernel<float>);
#endif // PADDLE_WITH_XPU_KP
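In the grad kernel above, get_rdims decides which axes of dz must be reduce-summed to recover dx and dy: an axis is reduced exactly when the broadcast-aligned input extent differs from the output extent on that axis. A small standalone sketch of that behaviour (plain C++ mirroring the helper above, not Paddle code):

// get_rdims_demo.cc -- standalone illustration of the get_rdims helper used
// by ElementwiseAddGradXPUKPKernel above; not part of the Paddle sources.
#include <cassert>
#include <cstddef>
#include <vector>

// Same logic as the helper above: an axis is reduced in the gradient when the
// broadcast-aligned input extent differs from the output extent on that axis.
static std::vector<int> get_rdims(const std::vector<int>& xdims,
                                  const std::vector<int>& ydims) {
  std::vector<int> rdims;
  for (size_t i = 0; i < xdims.size(); i++) {
    if (xdims[i] != ydims[i]) {
      rdims.push_back(static_cast<int>(i));
    }
  }
  return rdims;
}

int main() {
  // y of shape [100] added to x of shape [2, 100, 3] with axis = 1:
  // y is aligned to [1, 100, 1] and dz has shape [2, 100, 3], so dy is
  // xpu::reduce_sum of dz over axes 0 and 2, while dx needs no reduction.
  std::vector<int> y_aligned = {1, 100, 1};
  std::vector<int> z_dims = {2, 100, 3};
  assert((get_rdims(y_aligned, z_dims) == std::vector<int>{0, 2}));
  assert(get_rdims(z_dims, z_dims).empty());
  return 0;
}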
@@ -74,7 +74,10 @@ limitations under the License. */
 #include "paddle/fluid/platform/device/device_ext.h"
 #include "paddle/fluid/platform/device/stream.h"
+#if !defined(PADDLE_WITH_XPU_KP) || defined(__xpu_on_host__)
 #include "unsupported/Eigen/CXX11/Tensor"
+#endif
 namespace Eigen {
 struct DefaultDevice;
...
@@ -18,14 +18,14 @@
 #include <hip/hip_runtime.h>
 #endif
-#ifdef __xpu_kp__
+#if defined(__xpu__)
 #include <xpu/runtime.h>
 #include "xpu/kernel/cluster_header.h"
 #include "xpu/kernel/debug.h"
 #include "xpu/kernel/math.h"
 #endif
-#if (defined(__CUDACC__) || defined(__HIPCC__) || defined(__xpu_kp__))
+#if (defined(__CUDACC__) || defined(__HIPCC__) || defined(__xpu__))
 #define HOSTDEVICE __host__ __device__
 #define DEVICE __device__
 #define HOST __host__
...
@@ -16,7 +16,7 @@ limitations under the License. */
 #include "paddle/phi/kernels/funcs/elementwise_base.h"
-#if defined(__NVCC__) || defined(__HIPCC__)
+#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)
 namespace kps = phi::kps;
@@ -122,7 +122,7 @@ struct DimensionsTransform {
 explicit DimensionsTransform(const std::vector<const DenseTensor *> &ins,
 const phi::DDim &dims,
 int axis) {
-const int N = max(static_cast<int>(ins.size()), 2);
+const int N = std::max(static_cast<int>(ins.size()), 2);
 dim_size = dims.size();
 out_dims = phi::vectorize<int64_t>(dims);
 in_dims.resize(N);
@@ -183,7 +183,7 @@ struct DimensionsTransform {
 }
 };
-#if defined(__NVCC__) || defined(__HIPCC__)
+#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)
 template <typename T, int VecSize, int Rank, bool IsBoundary = false>
 __device__ __forceinline__ void LoadData(
@@ -268,7 +268,7 @@ __global__ void VectorizedBroadcastKernel(
 int block_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
 int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize;
-#ifdef PADDLE_WITH_XPU2
+#ifdef PADDLE_WITH_XPU_KP
 for (; block_offset < main_offset; block_offset += stride) {
 VectorizedBroadcastKernelImpl<InT,
 OutT,
@@ -348,12 +348,12 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
 phi::Array<_ptr_ OutT *, NumOuts> outs_data;
 for (int i = 0; i < NumOuts; ++i) {
-outs_data[i] = ctx.Alloc<OutT>((*outs)[i]);
+outs_data[i] = (_ptr_ OutT *)(ctx.Alloc<OutT>((*outs)[i]));
 }
 for (int i = 0; i < Arity; i++) {
 use_broadcast[i] = (ins[i]->numel() != numel);
-ins_data[i] = (_ptr_ InT *)(ins[i]->data<InT>());
+ins_data[i] = (const _ptr_ InT *)(ins[i]->data<InT>());
 if (use_broadcast[i]) {
 // get the broadcast config,
 // if data shape is[m, n], then you should set data_dim = {n, m}
@@ -363,7 +363,7 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
 }
 }
-#ifdef PADDLE_WITH_XPU2
+#ifdef PADDLE_WITH_XPU_KP
 const int threads = 64;
 const int blocks = 8;
 int main_offset = (numel / (VecSize * threads)) * VecSize * threads;
...
@@ -14,6 +14,8 @@
 #pragma once
+#ifndef __xpu__
 #include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/common/complex.h"
 #include "paddle/phi/common/float16.h"
@@ -435,3 +437,5 @@ HOSTDEVICE inline float16 maxi(const float16& a, const float16& b) {
 } // namespace numext
 } // namespace Eigen
+#endif // __xpu__
@@ -21,12 +21,13 @@ limitations under the License. */
 #include "paddle/phi/kernels/empty_kernel.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
-#if defined(__NVCC__) || defined(__HIPCC__)
+#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)
 #include "paddle/fluid/platform/aligned_vector.h"
 #include "paddle/fluid/platform/function_traits.h"
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/kernels/primitive/kernel_primitives.h"
+#define HOSTDEVICE __host__ __device__
 namespace kps = phi::kps;
 #endif
@@ -436,7 +437,7 @@ inline void ElementwiseGradPreProcess(const DenseTensor &dout,
 }
 }
-#if defined(__NVCC__) || defined(__HIPCC__)
+#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)
 // static unroller
 template <template <int Index, int VecSize> typename Func,
@@ -469,10 +470,14 @@ struct Loader {
 kps::Init<Type, ArgsT, Index, VecSize>(args, static_cast<Type>(1.0f));
 if (is_boundary) {
 kps::ReadData<Type, VecSize, 1, 1, ArgsT, Index, true>(
-args, reinterpret_cast<const Type *>(in[Index]) + data_offset, num);
+args,
+reinterpret_cast<const _ptr_ Type *>(in[Index]) + data_offset,
+num);
 } else {
 kps::ReadData<Type, VecSize, 1, 1, ArgsT, Index, false>(
-args, reinterpret_cast<const Type *>(in[Index]) + data_offset, num);
+args,
+reinterpret_cast<const _ptr_ Type *>(in[Index]) + data_offset,
+num);
 }
 }
 };
@@ -482,8 +487,7 @@ struct InputSetter {
 template <typename Array>
 static HOSTDEVICE void Apply(
 const std::vector<const DenseTensor *> &ins_tensor, Array *ins_data) {
-(*ins_data)[Index] =
-reinterpret_cast<const _ptr_ char *>(ins_tensor[Index]->data());
+(*ins_data)[Index] = (const _ptr_ char *)(ins_tensor[Index]->data());
 }
 };
@@ -718,9 +722,9 @@ void ElementwiseCudaKernel(const KPDevice &ctx,
 Unroller<InputSetter, VecSize, Arity>::step(ins, &ins_data);
 for (int i = 0; i < NumOuts; ++i) {
-outs_data[i] = ctx.Alloc<OutT>((*outs)[i]);
+outs_data[i] = (_ptr_ OutT *)(ctx.Alloc<OutT>((*outs)[i]));
 }
-#ifdef PADDLE_WITH_XPU2
+#ifdef PADDLE_WITH_XPU_KP
 int block_size = 64;
 int grid_size = 8;
 auto stream = ctx.x_context()->xpu_stream;
...
@@ -114,6 +114,7 @@ inline void ComputeBroadcastKernelSize(int *x_dims_array,
 }
 }
+#ifndef __xpu__
 template <typename T, typename OP, typename Tout = T>
 static __global__ void FastCommonGradBroadcastOneCUDAKernel(const T *x,
 const T *y,
@@ -128,8 +129,8 @@ static __global__ void FastCommonGradBroadcastOneCUDAKernel(const T *x,
 bool is_xsize,
 OP op,
 T *dd) {
-int tid = threadIdx.x;
-int bid = blockIdx.x;
+int tid = THREAD_ID_X;
+int bid = BLOCK_ID_X;
 T val(0);
 if (is_xsize) {
@@ -196,8 +197,8 @@ static __global__ void FastCommonGradBroadcastAllCUDAKernel(
 DY_OP dy_op,
 T *dx,
 T *dy) {
-int tid = threadIdx.x;
-int bid = blockIdx.x;
+int tid = THREAD_ID_X;
+int bid = BLOCK_ID_X;
 T val(0);
 if (is_xsize_larger) {
@@ -260,67 +261,67 @@ static __global__ void FastCommonGradBroadcastCUDAKernelHeight(const T *x,
 __shared__ T sdata[BLOCK_Y][BLOCK_X + 1];
 T val(0);
-size_t width_stride = gridDim.x * blockDim.x;
-size_t idx = threadIdx.x + blockDim.x * blockIdx.x;
+size_t width_stride = GRID_NUM_X * BLOCK_NUM_X;
+size_t idx = THREAD_ID_X + BLOCK_NUM_X * BLOCK_ID_X;
 size_t full_width =
 (w & (~((uint64_t)(BLOCK_X - 1)))) + ((w & (BLOCK_X - 1)) ? BLOCK_X : 0);
 size_t full_height =
 (h & (~((uint64_t)(BLOCK_Y - 1)))) + ((h & (BLOCK_Y - 1)) ? BLOCK_Y : 0);
 if (is_y) {
 for (int m = idx; m < full_width; m += width_stride) {
-sdata[threadIdx.y][threadIdx.x] = 0;
-for (int n = threadIdx.y; n < full_height; n += BLOCK_Y) {
+sdata[THREAD_ID_Y][THREAD_ID_X] = 0;
+for (int n = THREAD_ID_Y; n < full_height; n += BLOCK_Y) {
 int out_offset = n * w + m;
 int x_offset = (n % x_h) * x_w + m % x_w;
 if (dy) {
 if (m < w && n < h) {
 T val = dy_op(x[x_offset], y[m], out[out_offset], dout[out_offset]);
-sdata[threadIdx.y][threadIdx.x] += val;
+sdata[THREAD_ID_Y][THREAD_ID_X] += val;
 }
 __syncthreads();
 }
 }
 if (dy) {
-T my_val = sdata[threadIdx.x][threadIdx.y];
+T my_val = sdata[THREAD_ID_X][THREAD_ID_Y];
 for (int i = warpSize >> 1; i > 0; i >>= 1) {
 my_val += paddle::platform::CudaShuffleXorSync(0xFFFFFFFF, my_val, i);
 }
 __syncthreads();
-if ((threadIdx.x == 0)) {
-sdata[0][threadIdx.y] = my_val;
+if ((THREAD_ID_X == 0)) {
+sdata[0][THREAD_ID_Y] = my_val;
 }
 __syncthreads();
-if (threadIdx.y == 0 && m < w) {
-dy[m] = sdata[0][threadIdx.x];
+if (THREAD_ID_Y == 0 && m < w) {
+dy[m] = sdata[0][THREAD_ID_X];
 }
 }
 }
 } else {
 for (int m = idx; m < full_width; m += width_stride) {
-sdata[threadIdx.y][threadIdx.x] = 0;
-for (int n = threadIdx.y; n < full_height; n += BLOCK_Y) {
+sdata[THREAD_ID_Y][THREAD_ID_X] = 0;
+for (int n = THREAD_ID_Y; n < full_height; n += BLOCK_Y) {
 int out_offset = n * w + m;
 int y_offset = (n % x_h) * x_w + m % x_w;
 if (dy) {
 if (m < w && n < h) {
 T val = dy_op(x[m], y[y_offset], out[out_offset], dout[out_offset]);
-sdata[threadIdx.y][threadIdx.x] += val;
+sdata[THREAD_ID_Y][THREAD_ID_X] += val;
 }
 __syncthreads();
 }
 }
 if (dy) {
-T my_val = sdata[threadIdx.x][threadIdx.y];
+T my_val = sdata[THREAD_ID_X][THREAD_ID_Y];
 for (int i = warpSize >> 1; i > 0; i >>= 1) {
 my_val += paddle::platform::CudaShuffleXorSync(0xFFFFFFFF, my_val, i);
 }
 __syncthreads();
-if ((threadIdx.x == 0)) {
-sdata[0][threadIdx.y] = my_val;
+if ((THREAD_ID_X == 0)) {
+sdata[0][THREAD_ID_Y] = my_val;
 }
 __syncthreads();
-if (threadIdx.y == 0 && m < w) {
-dy[m] = sdata[0][threadIdx.x];
+if (THREAD_ID_Y == 0 && m < w) {
+dy[m] = sdata[0][THREAD_ID_X];
 }
 }
 }
@@ -339,9 +340,9 @@ static __global__ void CommonGradBroadcast1CUDAKernelHeight(const T *x,
 int x_h,
 int x_w,
 bool is_y) {
-int j = blockIdx.x;
-int i = threadIdx.x;
-int tid = threadIdx.x;
+int j = BLOCK_ID_X;
+int i = THREAD_ID_X;
+int tid = THREAD_ID_X;
 T val(0);
 if (is_y) {
@@ -357,7 +358,7 @@ static __global__ void CommonGradBroadcast1CUDAKernelHeight(const T *x,
 if (dy) {
 h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
 val = paddle::platform::reduceSum(val, tid, h);
-if (threadIdx.x == 0) {
+if (THREAD_ID_X == 0) {
 dy[j] = val;
 }
 }
@@ -374,7 +375,7 @@ static __global__ void CommonGradBroadcast1CUDAKernelHeight(const T *x,
 if (dy) {
 h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
 val = paddle::platform::reduceSum(val, tid, h);
-if (threadIdx.x == 0) {
+if (THREAD_ID_X == 0) {
 dy[j] = val;
 }
 }
@@ -393,9 +394,9 @@ static __global__ void ElemwiseGradBroadcast1CUDAKernel(const T *x,
 DY_OP dy_op,
 T *dx,
 T *dy) {
-int j = blockIdx.x;
-int i = threadIdx.x;
-int tid = threadIdx.x;
+int j = BLOCK_ID_X;
+int i = THREAD_ID_X;
+int tid = THREAD_ID_X;
 T val(0);
 if (is_xsize_larger) {
 do {
@@ -412,7 +413,7 @@ static __global__ void ElemwiseGradBroadcast1CUDAKernel(const T *x,
 if (dy) {
 h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
 val = paddle::platform::reduceSum(val, tid, h);
-if (threadIdx.x == 0) {
+if (THREAD_ID_X == 0) {
 dy[j] = val;
 }
 }
@@ -431,7 +432,7 @@ static __global__ void ElemwiseGradBroadcast1CUDAKernel(const T *x,
 if (dx) {
 h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
 val = paddle::platform::reduceSum(val, tid, h);
-if (threadIdx.x == 0) {
+if (THREAD_ID_X == 0) {
 dx[j] = val;
 }
 }
@@ -456,16 +457,16 @@ static __global__ void FastElemwiseGradBroadcast1CUDAKernel(
 __shared__ T sdata[BLOCK_Y][BLOCK_X + 1];
 T val(0);
-size_t width_stride = gridDim.x * blockDim.x;
-size_t idx = threadIdx.x + blockDim.x * blockIdx.x;
+size_t width_stride = GRID_NUM_X * BLOCK_NUM_X;
+size_t idx = THREAD_ID_X + BLOCK_NUM_X * BLOCK_ID_X;
 size_t full_width =
 (w & (~((uint64_t)(BLOCK_X - 1)))) + ((w & (BLOCK_X - 1)) ? BLOCK_X : 0);
 size_t full_height =
 (h & (~((uint64_t)(BLOCK_Y - 1)))) + ((h & (BLOCK_Y - 1)) ? BLOCK_Y : 0);
 if (is_xsize_larger) {
 for (int m = idx; m < full_width; m += width_stride) {
-sdata[threadIdx.y][threadIdx.x] = 0;
-for (int n = threadIdx.y; n < full_height; n += BLOCK_Y) {
+sdata[THREAD_ID_Y][THREAD_ID_X] = 0;
+for (int n = THREAD_ID_Y; n < full_height; n += BLOCK_Y) {
 int x_offset = n * w + m;
 if (dx && m < w && n < h) {
 dx[x_offset] =
@@ -474,29 +475,29 @@ static __global__ void FastElemwiseGradBroadcast1CUDAKernel(
 if (dy) {
 if (m < w && n < h) {
 T val = dy_op(x[x_offset], y[m], out[x_offset], dout[x_offset]);
-sdata[threadIdx.y][threadIdx.x] += val;
+sdata[THREAD_ID_Y][THREAD_ID_X] += val;
 }
 __syncthreads();
 }
 }
 if (dy) {
-T my_val = sdata[threadIdx.x][threadIdx.y];
+T my_val = sdata[THREAD_ID_X][THREAD_ID_Y];
 for (int i = warpSize >> 1; i > 0; i >>= 1)
 my_val += paddle::platform::CudaShuffleXorSync(0xFFFFFFFF, my_val, i);
 __syncthreads();
-if ((threadIdx.x == 0)) {
-sdata[0][threadIdx.y] = my_val;
+if ((THREAD_ID_X == 0)) {
+sdata[0][THREAD_ID_Y] = my_val;
 }
 __syncthreads();
-if (threadIdx.y == 0 && m < w) {
-dy[m] = sdata[0][threadIdx.x];
+if (THREAD_ID_Y == 0 && m < w) {
+dy[m] = sdata[0][THREAD_ID_X];
 }
 }
 }
 } else { // x.dims < y.dims, broadcast for x.
 for (int m = idx; m < full_width; m += width_stride) {
-sdata[threadIdx.y][threadIdx.x] = 0;
-for (int n = threadIdx.y; n < full_height; n += BLOCK_Y) {
+sdata[THREAD_ID_Y][THREAD_ID_X] = 0;
+for (int n = THREAD_ID_Y; n < full_height; n += BLOCK_Y) {
 int y_offset = n * w + m;
 if (dy && m < w && n < h) {
 dy[y_offset] =
@@ -505,22 +506,22 @@ static __global__ void FastElemwiseGradBroadcast1CUDAKernel(
 if (dx) {
 if (m < w && n < h) {
 T val = dx_op(x[m], y[y_offset], out[y_offset], dout[y_offset]);
-sdata[threadIdx.y][threadIdx.x] += val;
+sdata[THREAD_ID_Y][THREAD_ID_X] += val;
 }
 __syncthreads();
 }
 }
 if (dx) {
-T my_val = sdata[threadIdx.x][threadIdx.y];
+T my_val = sdata[THREAD_ID_X][THREAD_ID_Y];
 for (int i = warpSize >> 1; i > 0; i >>= 1)
 my_val += paddle::platform::CudaShuffleXorSync(0xFFFFFFFF, my_val, i);
 __syncthreads();
-if ((threadIdx.x == 0)) {
-sdata[0][threadIdx.y] = my_val;
+if ((THREAD_ID_X == 0)) {
+sdata[0][THREAD_ID_Y] = my_val;
 }
 __syncthreads();
-if (threadIdx.y == 0 && m < w) {
-dx[m] = sdata[0][threadIdx.x];
+if (THREAD_ID_Y == 0 && m < w) {
+dx[m] = sdata[0][THREAD_ID_X];
 }
 }
 }
@@ -540,8 +541,8 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel(const T *x,
 DY_OP dy_op,
 T *dx,
 T *dy) {
-int tid = threadIdx.x;
-int j = blockIdx.x;
+int tid = THREAD_ID_X;
+int j = BLOCK_ID_X;
 T val(0);
 int ttid = tid;
@@ -569,7 +570,7 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel(const T *x,
 int h = pre * post;
 h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
 val = paddle::platform::reduceSum(val, tid, h);
-if (threadIdx.x == 0) {
+if (THREAD_ID_X == 0) {
 dy[j] = val;
 }
 }
@@ -596,7 +597,7 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel(const T *x,
 int h = pre * post;
 h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
 val = paddle::platform::reduceSum(val, tid, h);
-if (threadIdx.x == 0) {
+if (THREAD_ID_X == 0) {
 dx[j] = val;
 }
 }
@@ -668,9 +669,9 @@ __global__ void CommonGradBroadcastCUDAKernel(const int *x_strides_array,
 int thread_num,
 DX_OP dx_op) {
 T val(0);
-int i = blockIdx.x;
-int tid = threadIdx.x;
-for (int j = tid; j < thread_num; j += blockDim.x) {
+int i = BLOCK_ID_X;
+int tid = THREAD_ID_X;
+for (int j = tid; j < thread_num; j += BLOCK_NUM_X) {
 const int X_index = i * thread_num + j;
 int out_index = X_index;
 int C_index = 0;
@@ -694,7 +695,7 @@ __global__ void CommonGradBroadcastCUDAKernel(const int *x_strides_array,
 val += dx_op(x[x_index], y[y_index], out[out_index], dout[out_index]);
 }
 val = paddle::platform::reduceSum(val, tid, thread_num);
-if (threadIdx.x == 0) {
+if (THREAD_ID_X == 0) {
 dx[i] = val;
 }
 }
@@ -1416,8 +1417,8 @@ void ElemwiseGradComputeWithBroadcast(const GPUContext &ctx,
 template <typename T>
 static __global__ void SimpleElemwiseAddGradCUDAKernel(
 const T *__restrict__ dout, int size, int vec_size, T *dx, T *dy) {
-int tid = blockIdx.x * blockDim.x + threadIdx.x;
-int stride = gridDim.x * blockDim.x;
+int tid = BLOCK_ID_X * BLOCK_NUM_X + THREAD_ID_X;
+int stride = GRID_NUM_X * BLOCK_NUM_X;
 int loop = size / vec_size;
 int remainder = size % vec_size;
 const float4 *dout_vec = reinterpret_cast<const float4 *>(dout);
@@ -1544,14 +1545,14 @@ static __global__ void SimpleElemwiseSubGradCUDAKernel(const T *dout,
 int64_t size,
 T *dx,
 T *dy) {
-int col = blockIdx.x * blockDim.x + threadIdx.x;
+int col = BLOCK_ID_X * BLOCK_NUM_X + THREAD_ID_X;
 while (col < size) {
 if (dx != nullptr) {
 dx[col] = dout[col];
 }
 dy[col] = -dout[col];
-col += blockDim.x * gridDim.x;
+col += BLOCK_NUM_X * GRID_NUM_X;
 }
 }
@@ -1629,4 +1630,6 @@ void elementwise_sub_grad(const GPUContext &ctx,
 dy->mutable_data<T>(ctx.GetPlace()));
 }
+#endif
 } // namespace phi
@@ -328,7 +328,7 @@ __device__ __forceinline__ void Reduce(T* out,
 const T* in,
 ReduceFunctor reducer,
 bool reduce_last_dim) {
-if (Mode == kGlobalMode) {
+if (Mode == details::kGlobalMode) {
 #pragma unroll
 for (int i = 0; i < NY; ++i) {
 #pragma unroll
@@ -336,7 +336,7 @@ __device__ __forceinline__ void Reduce(T* out,
 out[i] = reducer(out[i], in[i * NX + j]);
 }
 }
-BlockXReduce<T, OpFunc, NY>(out, reducer);
+BlockXReduce<T, ReduceFunctor, NY>(out, reducer);
 } else { // else kLocalMode
 #pragma unroll
 for (int i = 0; i < NY; ++i) {
...
@@ -34,9 +34,9 @@ struct alignas(sizeof(T) * VecSize) VectorType {
 #pragma pack(4)
 template <int kDims>
 struct BroadcastConfig {
-int strides_in[DDim::kMaxRank];
-int strides_out[DDim::kMaxRank];
-int in_dim[DDim::kMaxRank];
+int strides_in[phi::DDim::kMaxRank];
+int strides_out[phi::DDim::kMaxRank];
+int in_dim[phi::DDim::kMaxRank];
 HOSTDEVICE BroadcastConfig() {}
@@ -222,7 +222,7 @@ __device__ __forceinline__ void Init(ArgsT* dst, T init_data) {
 * src: The data pointer of the current block.
 * size: The current block needs to load size data continuously.
 */
-template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
+template <typename T, int NX, int NY, int BlockSize, bool IsBoundary>
 __device__ __inline__ void ReadData(T* dst,
 const T _global_ptr_* src,
 int num) {
@@ -251,9 +251,9 @@ template <typename T,
 int BlockSize,
 typename ArgsT,
 int Index,
-bool IsBoundary = false>
+bool IsBoundary>
 __device__ __forceinline__ void ReadData(ArgsT* dst,
-const T* __restrict__ src,
+const T _global_ptr_* src,
 int num) {
 int thread_offset = core_id() * NX;
 __local__ T in_temp[1];
@@ -366,21 +366,24 @@ __device__ __inline__ void ReadDataBc(T* dst,
 * reduce_last_dim: Used to indicate whether the dimension of reduce contains
 * the lowest dimension.
 */
-template <typename T,
+template <typename Tx,
+typename Ty,
 int NX,
 int NY,
 int BlockSize,
 int Rank,
 typename IndexCal,
+typename Functor,
 bool IsBoundary = false>
-__device__ __inline__ void ReadDataReduce(T* dst,
-const T _global_ptr_* src,
+__device__ __forceinline__ void ReadDataReduce(Ty* dst,
+const Tx* __restrict__ src,
 int block_offset,
 const IndexCal& index_cal,
 int size_nx,
 int size_ny,
 int stride_nx,
 int stride_ny,
+Functor func,
 bool reduce_last_dim) {
 __local__ Tx in_temp[1];
 int thread_offset = 0;
@@ -618,10 +621,11 @@ template <typename T,
 int BlockSize,
 int Rank,
 bool IsBoundary = false>
-__device__ __inline__ void ReadDataBc(T* dst,
+__device__ __inline__ void ReadDataBc(
+T* dst,
 const T _global_ptr_* src,
 uint32_t block_offset,
-details::BroadcastConfig<Rank> config,
+const details::BroadcastConfig<Rank>& config,
 int total_num_output) {
 int thread_offset = block_offset + core_id() * NX;
 int index_src = 0;
...
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "xpu/kernel/cluster_header.h"
#include "xpu/kernel/debug.h"
#include "xpu/kernel/math.h"
namespace phi {
namespace kps {
/**
* @brief Default unary identity functor
*/
template <typename Tx, typename Ty = Tx>
struct IdentityFunctor {
inline IdentityFunctor() {}
explicit inline IdentityFunctor(int n) {}
inline Ty operator()(const Tx& x) const { return static_cast<Ty>(x); }
__device__ inline IdentityFunctor() {}
__device__ explicit inline IdentityFunctor(int n) {}
__device__ inline Ty operator()(const Tx& x) const {
return static_cast<Ty>(x);
}
__device__ inline void SetDiv(int n) {}
};
/**
* @brief Default unary div functor. Divide by a constant
*/
template <typename Tx, typename Ty = Tx>
struct DivideFunctor {
inline DivideFunctor() { n_inv = static_cast<Tx>(1.0f); }
explicit inline DivideFunctor(int n)
: n_inv(static_cast<Tx>(((float)1.0) / (static_cast<float>(n)))) {}
inline Ty operator()(const Tx& x) const { return static_cast<Ty>(x * n_inv); }
__device__ inline DivideFunctor() { n_inv = static_cast<Tx>(1.0f); }
__device__ inline DivideFunctor(int n)
: n_inv(static_cast<Tx>(((float)1.0) / (static_cast<float>(n)))) {}
__device__ inline Ty operator()(const Tx& x) const {
return static_cast<Ty>(x * n_inv);
}
__device__ inline void SetDiv(int n) {
n_inv = static_cast<Tx>(((float)1.0) / (static_cast<float>(n)));
}
private:
Tx n_inv;
};
/**
* @brief Default unary square functor
*/
template <typename Tx, typename Ty = Tx>
struct SquareFunctor {
HOSTDEVICE inline SquareFunctor() {}
HOSTDEVICE explicit inline SquareFunctor(int n) {}
HOSTDEVICE inline Ty operator()(const Tx& x) const {
return static_cast<Ty>(x) * static_cast<Ty>(x);
}
};
/****************************** Binary Functor ********************************/
/**
* @brief Default binary min functor
*/
template <typename T>
struct MinFunctor {
inline T initial() { /*return static_cast<T>(std::numeric_limits<T>::max());*/
}
__device__ T operator()(const T& a, const T& b) const {
return (b < a) ? b : a;
}
};
/**
* @brief Default binary max functor
*/
template <typename T>
struct MaxFunctor {
inline T initial() {
// return static_cast<T>(std::numeric_limits<T>::lowest());
}
__device__ T operator()(const T& a, const T& b) const {
return (b > a) ? b : a;
}
};
/**
* @brief Default binary add functor
*/
template <typename T>
struct AddFunctor {
inline T initial() { return static_cast<T>(0.0f); }
__device__ T operator()(const T a, const T b) const { return b + a; }
};
/**
* @brief Default binary mul functor
*/
template <typename T>
struct MulFunctor {
inline T initial() { return static_cast<T>(1.0f); }
__device__ T operator()(const T& a, const T& b) const { return b * a; }
};
/**
* @brief Default binary logic or functor
*/
template <typename T>
struct LogicalOrFunctor {
inline T initial() { return static_cast<T>(false); }
__device__ T operator()(const T& a, const T& b) const { return b || a; }
};
/**
* @brief Default binary logic and functor
*/
template <typename T>
struct LogicalAndFunctor {
inline T initial() { return static_cast<T>(true); }
__device__ T operator()(const T& a, const T& b) const { return b && a; }
};
/**
* @brief Default binary sub functor
*/
template <typename T>
struct SubFunctor {
inline T initial() { return static_cast<T>(0.0f); }
inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a - b; }
};
/**
* @brief Default binary div functor
*/
template <typename T, typename Enable = void>
struct DivFunctor {
inline T initial() { return static_cast<T>(1.0f); }
inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a / b; }
};
template <typename T>
struct DivFunctor<T,
typename std::enable_if<std::is_integral<T>::value>::type> {
inline T initial() { return static_cast<T>(1.0f); }
inline HOSTDEVICE T operator()(const T& a, const T& b) const {
// For int32/int64, need to check whether the divisor is zero.
PADDLE_ENFORCE_NE(b,
0,
phi::errors::InvalidArgument(
"Integer division by zero encountered "
"in (floor) divide. Please check the input value."));
return a / b;
}
};
/**
* @brief Default binary floor divide functor
*/
template <typename T>
struct FloorDivFunctor {
inline T initial() { return static_cast<T>(1.0f); }
inline HOSTDEVICE T operator()(const T& a, const T& b) const {
PADDLE_ENFORCE_NE(b,
0,
phi::errors::InvalidArgument(
"Integer division by zero encountered "
"in (floor) divide. Please check the input value."));
return static_cast<T>(std::trunc(a / b));
}
};
} // namespace kps
} // namespace phi
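These functor primitives follow a simple contract: initial() supplies the identity element for a reduction and operator() folds one more value in. A minimal host-side sketch of how a reduction consumes such a functor (plain C++ with the __device__ qualifier dropped; not Paddle code):

// functor_contract_demo.cc -- toy illustration of the functor contract used
// by the kernel primitives above: initial() is the identity, operator() combines.
#include <iostream>
#include <vector>

// Stand-in for kps::AddFunctor<T>, with __device__ removed so it compiles as
// ordinary C++.
template <typename T>
struct AddFunctor {
  T initial() { return static_cast<T>(0.0f); }
  T operator()(const T a, const T b) const { return b + a; }
};

// Fold all elements with the functor, starting from its identity value.
template <typename T, typename Reducer>
T ToyReduce(const std::vector<T>& data, Reducer reducer) {
  T acc = reducer.initial();  // 0 for Add, 1 for Mul, true for LogicalAnd, ...
  for (const T& v : data) {
    acc = reducer(acc, v);
  }
  return acc;
}

int main() {
  std::vector<float> xs = {1.f, 2.f, 3.f};
  std::cout << ToyReduce(xs, AddFunctor<float>()) << std::endl;  // prints 6
  return 0;
}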
@@ -17,7 +17,7 @@
 namespace phi {
 namespace kps {
-#ifdef PADDLE_WITH_XPU2
+#ifdef PADDLE_WITH_XPU_KP
 struct dim3 {
 int x;
 int y;
...
@@ -14,11 +14,7 @@
 #pragma once
 #include "paddle/phi/kernels/primitive/helper_primitives.h"
-#ifdef PADDLE_WITH_XPU2
+#ifdef PADDLE_WITH_XPU_KP
-#include "paddle/phi/backends/xpu/xpu_context.h"
-#include "paddle/phi/kernels/primitive/compute_primitives_xpu2.h"
-#include "paddle/phi/kernels/primitive/datamover_primitives_xpu2.h"
-#include "paddle/phi/kernels/primitive/functor_primitives_xpu2.h"
 #define KPStream XPUStream
 #define KPDevice phi::XPUContext
@@ -26,6 +22,11 @@
 #define __forceinline__ __inline__
 #define __restrict__
+#include "paddle/phi/backends/xpu/xpu_context.h"
+#include "paddle/phi/kernels/primitive/compute_primitives_xpu2.h"
+#include "paddle/phi/kernels/primitive/datamover_primitives_xpu2.h"
+#include "paddle/phi/kernels/primitive/functor_primitives_xpu2.h"
 #define THREAD_ID_X core_id()
 #define THREAD_ID_Y 0
 #define THREAD_ID_Z 0
...
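The grad kernels in the hunks above were ported by replacing raw CUDA builtins (threadIdx.x, blockIdx.x, blockDim.x, gridDim.x) with the THREAD_ID_X / BLOCK_ID_X / BLOCK_NUM_X / GRID_NUM_X macros, so one kernel body can serve both back ends. A rough sketch of the mapping this header family is expected to provide; the XPU side comes from the hunk above, while the CUDA side is the assumed GPU counterpart and is not quoted from this diff:

// portability_macros_sketch.h -- illustrative only, not the verbatim Paddle
// definitions; shows how the indexing macros are expected to resolve.
#pragma once

#ifdef PADDLE_WITH_XPU_KP
// XPU KP build (as in the kernel_primitives.h hunk above): one logical
// "thread" per XPU core; the remaining BLOCK_*/GRID_* macros are defined
// similarly in the real header.
#define THREAD_ID_X core_id()
#define THREAD_ID_Y 0
#define THREAD_ID_Z 0
#else
// CUDA/HIP build (assumed counterpart, not shown in this diff).
#define THREAD_ID_X threadIdx.x
#define BLOCK_NUM_X blockDim.x
#define BLOCK_ID_X blockIdx.x
#define GRID_NUM_X gridDim.x
#endif

// With these macros a grid-stride loop such as SimpleElemwiseSubGradCUDAKernel
// can be written once:
//   int col = BLOCK_ID_X * BLOCK_NUM_X + THREAD_ID_X;
//   for (; col < size; col += BLOCK_NUM_X * GRID_NUM_X) { ... }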
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import sys
sys.path.append("..")
import paddle
from op_test import OpTest, skip_check_grad_ci
from op_test_xpu import XPUOpTest
import unittest
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
paddle.enable_static()
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestElementwiseAddOp(XPUOpTest):
def setUp(self):
self.op_type = "elementwise_add"
self.init_dtype()
self.init_input_output()
self.init_axis()
self.init_max_relative_error()
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(self.x),
'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
}
self.attrs = {'axis': self.axis, 'use_mkldnn': self.use_mkldnn}
self.outputs = {'Out': self.out}
def test_check_output(self):
if paddle.is_compiled_with_xpu():
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
def test_check_grad_normal(self):
if paddle.is_compiled_with_xpu():
place = paddle.XPUPlace(0)
self.check_grad_with_place(
place, ['X', 'Y'],
'Out',
max_relative_error=self.max_relative_error)
def test_check_grad_ingore_x(self):
if paddle.is_compiled_with_xpu():
place = paddle.XPUPlace(0)
self.check_grad_with_place(
place, ['Y'],
'Out',
no_grad_set=set("X"),
max_relative_error=self.max_relative_error)
def test_check_grad_ingore_y(self):
if paddle.is_compiled_with_xpu():
place = paddle.XPUPlace(0)
self.check_grad_with_place(
place, ['X'],
'Out',
no_grad_set=set("Y"),
max_relative_error=self.max_relative_error)
def init_input_output(self):
self.x = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype)
self.y = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype)
self.out = np.add(self.x, self.y)
def init_dtype(self):
self.dtype = np.float32
def init_axis(self):
self.axis = -1
def init_max_relative_error(self):
self.max_relative_error = 0.006
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
@skip_check_grad_ci(
reason="[skip shape check] Use y_shape(1) to test broadcast.")
class TestElementwiseAddOp_scalar(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 3, 4).astype(self.dtype)
self.y = np.random.rand(1).astype(self.dtype)
self.out = self.x + self.y
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
@skip_check_grad_ci(
reason="[skip shape check] Use y_shape(1,1) to test broadcast.")
class TestElementwiseAddOp_scalar2(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 3, 4).astype(self.dtype)
self.y = np.random.rand(1, 1).astype(self.dtype)
self.out = self.x + self.y
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestElementwiseAddOp_Vector(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.random((100, )).astype(self.dtype)
self.y = np.random.random((100, )).astype(self.dtype)
self.out = np.add(self.x, self.y)
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestElementwiseAddOp_broadcast_0(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(100, 2, 3).astype(self.dtype)
self.y = np.random.rand(100).astype(self.dtype)
self.out = self.x + self.y.reshape(100, 1, 1)
def init_axis(self):
self.axis = 0
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestElementwiseAddOp_broadcast_1(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 100, 3).astype(self.dtype)
self.y = np.random.rand(100).astype(self.dtype)
self.out = self.x + self.y.reshape(1, 100, 1)
def init_axis(self):
self.axis = 1
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestElementwiseAddOp_broadcast_2(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 3, 100).astype(self.dtype)
self.y = np.random.rand(100).astype(self.dtype)
self.out = self.x + self.y.reshape(1, 1, 100)
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestElementwiseAddOp_broadcast_3(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 10, 12, 3).astype(self.dtype)
self.y = np.random.rand(10, 12).astype(self.dtype)
self.out = self.x + self.y.reshape(1, 10, 12, 1)
def init_axis(self):
self.axis = 1
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestElementwiseAddOp_broadcast_4(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(100, 2, 3, 4).astype(self.dtype)
self.y = np.random.rand(100, 1).astype(self.dtype)
self.out = self.x + self.y.reshape(100, 1, 1, 1)
def init_axis(self):
self.axis = 0
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestElementwiseAddOp_broadcast_5(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(10, 3, 12).astype(self.dtype)
self.y = np.random.rand(10, 1, 12).astype(self.dtype)
self.out = self.x + self.y
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestElementwiseAddOp_broadcast_6(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 12, 3, 5).astype(self.dtype)
self.y = np.random.rand(2, 12, 1, 5).astype(self.dtype)
self.out = self.x + self.y
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestElementwiseAddOp_broadcast_7(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(1, 1, 20, 5).astype(self.dtype)
self.y = np.random.rand(20, 5, 1, 1).astype(self.dtype)
self.out = self.x + self.y
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestElementwiseAddOp_rowwise_add_0(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 10, 12).astype(self.dtype)
self.y = np.random.rand(10, 12).astype(self.dtype)
self.out = self.x + self.y.reshape(1, 10, 12)
def init_axis(self):
self.axis = 1
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
@skip_check_grad_ci(
reason="[skip shape check] Use y_shape(1) to test broadcast.")
class TestElementwiseAddOp_rowwise_add_1(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(100, 1).astype(self.dtype)
self.y = np.random.rand(1).astype(self.dtype)
self.out = self.x + self.y.reshape(1, 1)
def init_axis(self):
self.axis = 1
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestElementwiseAddOp_channelwise_add(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(100, 2, 3).astype(self.dtype)
self.y = np.random.rand(100, 1, 1).astype(self.dtype)
self.out = self.x + self.y
def init_axis(self):
self.axis = -1
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestElementwiseAddOp_commonuse_add1(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 3, 100).astype(self.dtype)
self.y = np.random.rand(1, 1, 100).astype(self.dtype)
self.out = self.x + self.y
def init_axis(self):
self.axis = -1
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestElementwiseAddOp_commonuse_add2(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(10, 3, 1, 4).astype(self.dtype)
self.y = np.random.rand(10, 1, 12, 1).astype(self.dtype)
self.out = self.x + self.y
def init_axis(self):
self.axis = -1
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestElementwiseAddOp_xsize_lessthan_ysize_add(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(10, 12).astype(self.dtype)
self.y = np.random.rand(2, 3, 10, 12).astype(self.dtype)
self.out = self.x + self.y
def init_axis(self):
self.axis = 2
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestElementwiseAddOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
# the input of elementwise_add must be Variable.
x1 = fluid.create_lod_tensor(
np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.XPUPlace(0))
y1 = fluid.create_lod_tensor(
np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.XPUPlace(0))
self.assertRaises(TypeError, fluid.layers.elementwise_add, x1, y1)
# the input dtype of elementwise_add must be float16 or float32 or float64 or int32 or int64
# float16 only can be set on GPU place
x2 = fluid.layers.data(name='x2', shape=[3, 4, 5, 6], dtype="uint8")
y2 = fluid.layers.data(name='y2', shape=[3, 4, 5, 6], dtype="uint8")
self.assertRaises(TypeError, fluid.layers.elementwise_add, x2, y2)
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestAddOp(unittest.TestCase):
def test_name(self):
with fluid.program_guard(fluid.Program()):
x = fluid.data(name="x", shape=[2, 3], dtype="float32")
y = fluid.data(name='y', shape=[2, 3], dtype='float32')
y_1 = paddle.add(x, y, name='add_res')
self.assertEqual(('add_res' in y_1.name), True)
def test_declarative(self):
with fluid.program_guard(fluid.Program()):
def gen_data():
return {
"x": np.array([2, 3, 4]).astype('float32'),
"y": np.array([1, 5, 2]).astype('float32')
}
x = fluid.data(name="x", shape=[3], dtype='float32')
y = fluid.data(name="y", shape=[3], dtype='float32')
z = paddle.add(x, y)
place = fluid.XPUPlace(0)
exe = fluid.Executor(place)
z_value = exe.run(feed=gen_data(), fetch_list=[z.name])
z_expected = np.array([3., 8., 6.])
self.assertEqual((z_value == z_expected).all(), True)
def test_dygraph(self):
with fluid.dygraph.guard():
np_x = np.array([2, 3, 4]).astype('float32')
np_y = np.array([1, 5, 2]).astype('float32')
x = fluid.dygraph.to_variable(np_x)
y = fluid.dygraph.to_variable(np_y)
z = paddle.add(x, y)
np_z = z.numpy()
z_expected = np.array([3., 8., 6.])
self.assertEqual((np_z == z_expected).all(), True)
if __name__ == '__main__':
unittest.main()