Unverified commit dbc08d69, authored by chentianyu03 and committed by GitHub

modify complex template for elementwise ops (#33071)

* modify complex template for elementwise ops

* modify mul, div grad struct

* add complex template specializations for the CudaShuffleDownSync and CudaShuffleXorSync functions, and fix the bug introduced when the CUDA_VERSION < 9000 branch was deleted

* fix shuffle func args bug

* fix shuffle func args bug

* fix shuffle func args bug
Parent 3a7b9ed7
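In short, the hand-written paddle::platform::complex64 and paddle::platform::complex128 specializations throughout the elementwise ops are folded into a single partial specialization over paddle::platform::complex<T>. A minimal sketch of the pattern, mirroring the DivGradDX change in the diff below (illustrative, not a verbatim excerpt; HOSTDEVICE is Paddle's host/device qualifier macro):

    // Before: one hand-written specialization per precision.
    template <>
    struct DivGradDX<paddle::platform::complex64> { /* dout / conj(y) */ };
    template <>
    struct DivGradDX<paddle::platform::complex128> { /* dout / conj(y) */ };

    // After: one partial specialization covers both precisions.
    template <typename T>
    struct DivGradDX<paddle::platform::complex<T>> {
      HOSTDEVICE paddle::platform::complex<T> operator()(
          paddle::platform::complex<T> x, paddle::platform::complex<T> y,
          paddle::platform::complex<T> out,
          paddle::platform::complex<T> dout) const {
        // For out = x / y, the gradient w.r.t. x is dout / conj(y).
        paddle::platform::complex<T> y_conj(y.real, -y.imag);
        return dout / y_conj;
      }
    };

The CudaShuffleDownSync and CudaShuffleXorSync specializations follow the same templating pattern; in addition, the HIP branch, which lacks the *_sync shuffle intrinsics, now gets its own float16 and complex specializations built on __shfl_down and __shfl_xor.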
@@ -20,8 +20,8 @@ limitations under the License. */
namespace paddle {
namespace platform {
struct complex128;
struct complex64;
template <typename T>
struct complex;
} // namespace platform
} // namespace paddle
@@ -135,9 +135,9 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
elementwise_add_grad,
ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, float>,
@@ -145,9 +145,9 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
elementwise_add_grad_grad,
ops::ElementwiseAddDoubleGradKernel<paddle::platform::CPUDeviceContext,
@@ -159,9 +159,9 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseAddDoubleGradKernel<paddle::platform::CPUDeviceContext,
int64_t>,
ops::ElementwiseAddDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::ElementwiseAddDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
// A specialization of the elementwise_add operator, used in gradient accumulation with
// inplace addto.
@@ -178,9 +178,9 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
REGISTER_OP_VERSION(elementwise_add)
.AddCheckpoint(
@@ -13,8 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators;
@@ -141,8 +140,8 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseAddKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseAddKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::float16>,
ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex64>,
ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex128>);
ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex<float>>,
ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
elementwise_add_grad,
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, float>,
@@ -150,8 +149,10 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, plat::float16>,
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, plat::complex64>,
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, plat::complex128>);
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext,
plat::complex<float>>,
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext,
plat::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
elementwise_add_grad_grad,
ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext, float>,
@@ -160,9 +161,9 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext, plat::float16>,
ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext,
plat::complex64>,
plat::complex<float>>,
ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext,
plat::complex128>);
plat::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
grad_add, ops::ElementwiseAddKernel<plat::CUDADeviceContext, float>,
@@ -170,5 +171,5 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseAddKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseAddKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::float16>,
ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex64>,
ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex128>);
ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex<float>>,
ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex<double>>);
@@ -17,8 +17,7 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/complex.h"
namespace paddle {
namespace operators {
@@ -135,9 +134,9 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
elementwise_div_grad,
ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext, float>,
@@ -145,9 +144,9 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
elementwise_div_grad_grad,
@@ -160,9 +159,9 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CPUDeviceContext,
int64_t>,
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
REGISTER_OP_VERSION(elementwise_div)
.AddCheckpoint(
@@ -14,8 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_div_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators;
@@ -76,18 +75,21 @@ static __global__ void SimpleElemwiseDivGradCUDAKernel(const T* x, const T* y,
}
template <>
__global__ void SimpleElemwiseDivGradCUDAKernel<paddle::platform::complex64>(
const paddle::platform::complex64* x, const paddle::platform::complex64* y,
const paddle::platform::complex64* out,
const paddle::platform::complex64* dout, int64_t size,
paddle::platform::complex64* dx, paddle::platform::complex64* dy) {
__global__ void
SimpleElemwiseDivGradCUDAKernel<paddle::platform::complex<float>>(
const paddle::platform::complex<float>* x,
const paddle::platform::complex<float>* y,
const paddle::platform::complex<float>* out,
const paddle::platform::complex<float>* dout, int64_t size,
paddle::platform::complex<float>* dx,
paddle::platform::complex<float>* dy) {
int col = blockIdx.x * blockDim.x + threadIdx.x;
while (col < size) {
paddle::platform::complex64 o = dout[col];
paddle::platform::complex64 y_conj(y[col].real, -y[col].imag);
paddle::platform::complex64 out_div_y_conj((out[col] / y[col]).real,
-(out[col] / y[col]).imag);
paddle::platform::complex<float> o = dout[col];
paddle::platform::complex<float> y_conj(y[col].real, -y[col].imag);
paddle::platform::complex<float> out_div_y_conj((out[col] / y[col]).real,
-(out[col] / y[col]).imag);
dx[col] = o / y_conj;
dy[col] = -o * out_div_y_conj;
col += blockDim.x * gridDim.x;
@@ -95,19 +97,21 @@ __global__ void SimpleElemwiseDivGradCUDAKernel<paddle::platform::complex64>(
}
template <>
__global__ void SimpleElemwiseDivGradCUDAKernel<paddle::platform::complex128>(
const paddle::platform::complex128* x,
const paddle::platform::complex128* y,
const paddle::platform::complex128* out,
const paddle::platform::complex128* dout, int64_t size,
paddle::platform::complex128* dx, paddle::platform::complex128* dy) {
__global__ void
SimpleElemwiseDivGradCUDAKernel<paddle::platform::complex<double>>(
const paddle::platform::complex<double>* x,
const paddle::platform::complex<double>* y,
const paddle::platform::complex<double>* out,
const paddle::platform::complex<double>* dout, int64_t size,
paddle::platform::complex<double>* dx,
paddle::platform::complex<double>* dy) {
int col = blockIdx.x * blockDim.x + threadIdx.x;
while (col < size) {
paddle::platform::complex128 o = dout[col];
paddle::platform::complex128 y_conj(y[col].real, -y[col].imag);
paddle::platform::complex128 out_div_y_conj((out[col] / y[col]).real,
-(out[col] / y[col]).imag);
paddle::platform::complex<double> o = dout[col];
paddle::platform::complex<double> y_conj(y[col].real, -y[col].imag);
paddle::platform::complex<double> out_div_y_conj((out[col] / y[col]).real,
-(out[col] / y[col]).imag);
dx[col] = o / y_conj;
dy[col] = -o * out_div_y_conj;
col += blockDim.x * gridDim.x;
@@ -145,9 +149,9 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
elementwise_div_grad,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, float>,
@@ -157,9 +161,9 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
elementwise_div_grad_grad,
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
@@ -173,6 +177,6 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
int64_t>,
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
@@ -74,23 +74,13 @@ struct DivGradDX {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout / y; }
};
template <>
struct DivGradDX<paddle::platform::complex64> {
HOSTDEVICE paddle::platform::complex64 operator()(
paddle::platform::complex64 x, paddle::platform::complex64 y,
paddle::platform::complex64 out, paddle::platform::complex64 dout) const {
paddle::platform::complex64 y_conj(y.real, -y.imag);
return dout / y_conj;
}
};
template <>
struct DivGradDX<paddle::platform::complex128> {
HOSTDEVICE paddle::platform::complex128 operator()(
paddle::platform::complex128 x, paddle::platform::complex128 y,
paddle::platform::complex128 out,
paddle::platform::complex128 dout) const {
paddle::platform::complex128 y_conj(y.real, -y.imag);
template <typename T>
struct DivGradDX<paddle::platform::complex<T>> {
HOSTDEVICE paddle::platform::complex<T> operator()(
paddle::platform::complex<T> x, paddle::platform::complex<T> y,
paddle::platform::complex<T> out,
paddle::platform::complex<T> dout) const {
paddle::platform::complex<T> y_conj(y.real, -y.imag);
return dout / y_conj;
}
};
@@ -102,23 +92,13 @@ struct DivGradDY {
}
};
template <>
struct DivGradDY<paddle::platform::complex64> {
HOSTDEVICE paddle::platform::complex64 operator()(
paddle::platform::complex64 x, paddle::platform::complex64 y,
paddle::platform::complex64 out, paddle::platform::complex64 dout) const {
paddle::platform::complex64 out_div_y_conj((out / y).real, -(out / y).imag);
return -dout * out_div_y_conj;
}
};
template <>
struct DivGradDY<paddle::platform::complex128> {
HOSTDEVICE paddle::platform::complex128 operator()(
paddle::platform::complex128 x, paddle::platform::complex128 y,
paddle::platform::complex128 out,
paddle::platform::complex128 dout) const {
paddle::platform::complex128 out_div_y_conj((out / y).real,
template <typename T>
struct DivGradDY<paddle::platform::complex<T>> {
HOSTDEVICE paddle::platform::complex<T> operator()(
paddle::platform::complex<T> x, paddle::platform::complex<T> y,
paddle::platform::complex<T> out,
paddle::platform::complex<T> dout) const {
paddle::platform::complex<T> out_div_y_conj((out / y).real,
-(out / y).imag);
return -dout * out_div_y_conj;
}
@@ -16,8 +16,7 @@ limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/complex.h"
namespace paddle {
namespace operators {
@@ -134,9 +133,9 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
elementwise_mul_grad,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, float>,
@@ -144,9 +143,9 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
elementwise_mul_grad_grad,
ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
@@ -158,9 +157,9 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
int64_t>,
ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
REGISTER_OP_VERSION(elementwise_mul)
.AddCheckpoint(
@@ -14,8 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators;
@@ -76,31 +75,31 @@ static __global__ void SimpleElemwiseMulGradCUDAKernel(const T* x, const T* y,
}
template <>
__global__ void SimpleElemwiseMulGradCUDAKernel<plat::complex64>(
const plat::complex64* x, const plat::complex64* y,
const plat::complex64* out, const plat::complex64* dout, int64_t size,
plat::complex64* dx, plat::complex64* dy) {
__global__ void SimpleElemwiseMulGradCUDAKernel<plat::complex<float>>(
const plat::complex<float>* x, const plat::complex<float>* y,
const plat::complex<float>* out, const plat::complex<float>* dout,
int64_t size, plat::complex<float>* dx, plat::complex<float>* dy) {
int col = blockIdx.x * blockDim.x + threadIdx.x;
while (col < size) {
plat::complex64 o = dout[col];
dx[col] = plat::complex64(y[col].real, -y[col].imag) * o;
dy[col] = plat::complex64(x[col].real, -x[col].imag) * o;
plat::complex<float> o = dout[col];
dx[col] = plat::complex<float>(y[col].real, -y[col].imag) * o;
dy[col] = plat::complex<float>(x[col].real, -x[col].imag) * o;
col += blockDim.x * gridDim.x;
}
}
template <>
__global__ void SimpleElemwiseMulGradCUDAKernel<plat::complex128>(
const plat::complex128* x, const plat::complex128* y,
const plat::complex128* out, const plat::complex128* dout, int64_t size,
plat::complex128* dx, plat::complex128* dy) {
__global__ void SimpleElemwiseMulGradCUDAKernel<plat::complex<double>>(
const plat::complex<double>* x, const plat::complex<double>* y,
const plat::complex<double>* out, const plat::complex<double>* dout,
int64_t size, plat::complex<double>* dx, plat::complex<double>* dy) {
int col = blockIdx.x * blockDim.x + threadIdx.x;
while (col < size) {
plat::complex128 o = dout[col];
dx[col] = plat::complex128(y[col].real, -y[col].imag) * o;
dy[col] = plat::complex128(x[col].real, -x[col].imag) * o;
plat::complex<double> o = dout[col];
dx[col] = plat::complex<double>(y[col].real, -y[col].imag) * o;
dy[col] = plat::complex<double>(x[col].real, -x[col].imag) * o;
col += blockDim.x * gridDim.x;
}
}
@@ -133,8 +132,8 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseMulKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::float16>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::complex64>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::complex128>);
ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::complex<float>>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
elementwise_mul_grad,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, float>,
@@ -142,8 +141,10 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, plat::float16>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, plat::complex64>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, plat::complex128>);
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext,
plat::complex<float>>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext,
plat::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
elementwise_mul_grad_grad,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, float>,
@@ -152,6 +153,6 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, plat::float16>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext,
plat::complex64>,
plat::complex<float>>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext,
plat::complex128>);
plat::complex<double>>);
@@ -132,23 +132,13 @@ struct MulGradDX {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * y; }
};
template <>
struct MulGradDX<paddle::platform::complex64> {
HOSTDEVICE paddle::platform::complex64 operator()(
paddle::platform::complex64 x, paddle::platform::complex64 y,
paddle::platform::complex64 out, paddle::platform::complex64 dout) const {
paddle::platform::complex64 y_conj(y.real, -y.imag);
return dout * y_conj;
}
};
template <>
struct MulGradDX<paddle::platform::complex128> {
HOSTDEVICE paddle::platform::complex128 operator()(
paddle::platform::complex128 x, paddle::platform::complex128 y,
paddle::platform::complex128 out,
paddle::platform::complex128 dout) const {
paddle::platform::complex128 y_conj(y.real, -y.imag);
template <typename T>
struct MulGradDX<paddle::platform::complex<T>> {
HOSTDEVICE paddle::platform::complex<T> operator()(
paddle::platform::complex<T> x, paddle::platform::complex<T> y,
paddle::platform::complex<T> out,
paddle::platform::complex<T> dout) const {
paddle::platform::complex<T> y_conj(y.real, -y.imag);
return dout * y_conj;
}
};
@@ -158,23 +148,13 @@ struct MulGradDY {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * x; }
};
template <>
struct MulGradDY<paddle::platform::complex64> {
HOSTDEVICE paddle::platform::complex64 operator()(
paddle::platform::complex64 x, paddle::platform::complex64 y,
paddle::platform::complex64 out, paddle::platform::complex64 dout) const {
paddle::platform::complex64 x_conj(x.real, -x.imag);
return dout * x_conj;
}
};
template <>
struct MulGradDY<paddle::platform::complex128> {
HOSTDEVICE paddle::platform::complex128 operator()(
paddle::platform::complex128 x, paddle::platform::complex128 y,
paddle::platform::complex128 out,
paddle::platform::complex128 dout) const {
paddle::platform::complex128 x_conj(x.real, -x.imag);
template <typename T>
struct MulGradDY<paddle::platform::complex<T>> {
HOSTDEVICE paddle::platform::complex<T> operator()(
paddle::platform::complex<T> x, paddle::platform::complex<T> y,
paddle::platform::complex<T> out,
paddle::platform::complex<T> dout) const {
paddle::platform::complex<T> x_conj(x.real, -x.imag);
return dout * x_conj;
}
};
@@ -20,8 +20,8 @@ limitations under the License. */
namespace paddle {
namespace platform {
struct complex128;
struct complex64;
template <typename T>
struct complex;
} // namespace platform
} // namespace paddle
@@ -134,9 +134,9 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
elementwise_sub_grad,
ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext, float>,
@@ -144,9 +144,9 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
elementwise_sub_grad_grad,
ops::ElementwiseSubDoubleGradKernel<paddle::platform::CPUDeviceContext,
@@ -158,9 +158,9 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseSubDoubleGradKernel<paddle::platform::CPUDeviceContext,
int64_t>,
ops::ElementwiseSubDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::ElementwiseSubDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
REGISTER_OP_VERSION(elementwise_sub)
.AddCheckpoint(
@@ -14,8 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/elementwise/elementwise_sub_op.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators;
@@ -103,9 +102,9 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
elementwise_sub_grad,
ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext, float>,
@@ -115,9 +114,9 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
elementwise_sub_grad_grad,
ops::ElementwiseSubDoubleGradKernel<paddle::platform::CUDADeviceContext,
@@ -129,6 +128,6 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseSubDoubleGradKernel<paddle::platform::CUDADeviceContext,
int64_t>,
ops::ElementwiseSubDoubleGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::ElementwiseSubDoubleGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
@@ -16,8 +16,7 @@ limitations under the License. */
// NOTE: support converting float16 to half in this header file.
#define PADDLE_CUDA_FP16
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
@@ -82,28 +81,52 @@ __forceinline__ __device__ T CudaShuffleXorSync(unsigned mask, T val,
#endif
}
// CUDA 9.0 has a natively compatible float16 shfl_down
#if defined(PADDLE_WITH_HIP)
template <>
__forceinline__ __device__ float16 CudaShuffleDownSync(unsigned mask,
float16 val, int delta,
int width) {
#ifdef PADDLE_WITH_HIP
return float16(__shfl_down(static_cast<float>(val),
static_cast<unsigned>(delta), width));
#else
return float16(
__shfl_down(static_cast<half>(val), static_cast<unsigned>(delta), width));
#endif
}
template <>
__forceinline__ __device__ paddle::platform::complex<float> CudaShuffleDownSync(
unsigned mask, paddle::platform::complex<float> val, int delta, int width) {
float real = __shfl_down(val.real, delta, width);
float imag = __shfl_down(val.imag, delta, width);
return paddle::platform::complex<float>(real, imag);
}
template <>
__forceinline__ __device__ paddle::platform::complex<double>
CudaShuffleDownSync(unsigned mask, paddle::platform::complex<double> val,
int delta, int width) {
double real = __shfl_down(val.real, delta, width);
double imag = __shfl_down(val.imag, delta, width);
return paddle::platform::complex<double>(real, imag);
}
template <>
__forceinline__ __device__ float16 CudaShuffleXorSync(unsigned mask,
float16 val, int width) {
#ifdef PADDLE_WITH_HIP
return float16(__shfl_xor(static_cast<float>(val), width));
#else
return float16(__shfl_xor(static_cast<half>(val), width));
#endif
}
template <>
__forceinline__ __device__ paddle::platform::complex<float> CudaShuffleXorSync(
unsigned mask, paddle::platform::complex<float> val, int width) {
float real = __shfl_xor(val.real, width);
float imag = __shfl_xor(val.imag, width);
return paddle::platform::complex<float>(real, imag);
}
template <>
__forceinline__ __device__ paddle::platform::complex<double> CudaShuffleXorSync(
unsigned mask, paddle::platform::complex<double> val, int width) {
double real = __shfl_xor(val.real, width);
double imag = __shfl_xor(val.imag, width);
return paddle::platform::complex<double>(real, imag);
}
#else
template <>
@@ -115,25 +138,26 @@ __forceinline__ __device__ float16 CudaShuffleDownSync(unsigned mask,
}
template <>
__forceinline__ __device__ paddle::platform::complex64 CudaShuffleDownSync(
unsigned mask, paddle::platform::complex64 val, int delta, int width) {
__forceinline__ __device__ paddle::platform::complex<float> CudaShuffleDownSync(
unsigned mask, paddle::platform::complex<float> val, int delta, int width) {
float real = static_cast<float>(__shfl_down_sync(
mask, static_cast<float>(val.real), static_cast<unsigned>(delta), width));
float imag = static_cast<float>(__shfl_down_sync(
mask, static_cast<float>(val.imag), static_cast<unsigned>(delta), width));
return paddle::platform::complex64(real, imag);
return paddle::platform::complex<float>(real, imag);
}
template <>
__forceinline__ __device__ paddle::platform::complex128 CudaShuffleDownSync(
unsigned mask, paddle::platform::complex128 val, int delta, int width) {
__forceinline__ __device__ paddle::platform::complex<double>
CudaShuffleDownSync(unsigned mask, paddle::platform::complex<double> val,
int delta, int width) {
double real = static_cast<double>(
__shfl_down_sync(mask, static_cast<double>(val.real),
static_cast<unsigned>(delta), width));
double imag = static_cast<double>(
__shfl_down_sync(mask, static_cast<double>(val.imag),
static_cast<unsigned>(delta), width));
return paddle::platform::complex128(real, imag);
return paddle::platform::complex<double>(real, imag);
}
template <>
@@ -143,23 +167,23 @@ __forceinline__ __device__ float16 CudaShuffleXorSync(unsigned mask,
}
template <>
__forceinline__ __device__ paddle::platform::complex64 CudaShuffleXorSync(
unsigned mask, paddle::platform::complex64 val, int width) {
__forceinline__ __device__ paddle::platform::complex<float> CudaShuffleXorSync(
unsigned mask, paddle::platform::complex<float> val, int width) {
float real = static_cast<float>(
__shfl_xor_sync(mask, static_cast<float>(val.real), width));
float imag = static_cast<float>(
__shfl_xor_sync(mask, static_cast<float>(val.imag), width));
return paddle::platform::complex64(real, imag);
return paddle::platform::complex<float>(real, imag);
}
template <>
__forceinline__ __device__ paddle::platform::complex128 CudaShuffleXorSync(
unsigned mask, paddle::platform::complex128 val, int width) {
__forceinline__ __device__ paddle::platform::complex<double> CudaShuffleXorSync(
unsigned mask, paddle::platform::complex<double> val, int width) {
double real = static_cast<double>(
__shfl_xor_sync(mask, static_cast<double>(val.real), width));
double imag = static_cast<double>(
__shfl_xor_sync(mask, static_cast<double>(val.imag), width));
return paddle::platform::complex128(real, imag);
return paddle::platform::complex<double>(real, imag);
}
#endif