Unverified commit dbc08d69, authored by chentianyu03, committed by GitHub

modify complex template for elementwise ops (#33071)

* modify complex template for elementwise ops

* modify mul, div grad struct

* add complex template for CudaShuffleDownSync and CudaShuffleXorSync funcs and fix the bug when deleting the CUDA < 9000 path

* fix shuffle func args bug

* fix shuffle func args bug

* fix shuffle func args bug
Parent 3a7b9ed7
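The change collapses the duplicated complex64/complex128 code paths into a single template over paddle::platform::complex<T>. Below is a minimal, self-contained C++ sketch of that pattern (toy complex type and functor names, not the actual Paddle sources): a partial specialization over complex<T> replaces one explicit specialization per precision.

// sketch.cc -- illustrative only; `complex` and `MulGradDX` here are
// simplified stand-ins for the paddle::platform types.
#include <iostream>

template <typename T>
struct complex {
  T real, imag;
  complex(T r, T i) : real(r), imag(i) {}
};

template <typename T>
complex<T> operator*(complex<T> a, complex<T> b) {
  return complex<T>(a.real * b.real - a.imag * b.imag,
                    a.real * b.imag + a.imag * b.real);
}

// Generic elementwise-mul dx functor, as in the grad kernels.
template <typename T>
struct MulGradDX {
  T operator()(T x, T y, T out, T dout) const { return dout * y; }
};

// Before: one explicit specialization each for complex64 and complex128.
// After: one partial specialization over complex<T> serves both widths.
template <typename T>
struct MulGradDX<complex<T>> {
  complex<T> operator()(complex<T> x, complex<T> y, complex<T> out,
                        complex<T> dout) const {
    complex<T> y_conj(y.real, -y.imag);  // dL/dx = dout * conj(y)
    return dout * y_conj;
  }
};

int main() {
  MulGradDX<complex<float>> dx32;   // instantiates the partial specialization
  MulGradDX<complex<double>> dx64;  // same template at double precision
  complex<float> g = dx32({1, 2}, {3, 4}, {3, 4}, {1, 0});
  complex<double> h = dx64({1, 2}, {3, 4}, {3, 4}, {1, 0});
  std::cout << g.real << " " << g.imag << "\n";  // prints: 3 -4
  std::cout << h.real << " " << h.imag << "\n";  // prints: 3 -4
  return 0;
}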
@@ -20,8 +20,8 @@ limitations under the License. */
 namespace paddle {
 namespace platform {
-struct complex128;
-struct complex64;
+template <typename T>
+struct complex;
 }  // namespace platform
 }  // namespace paddle
@@ -135,9 +135,9 @@ REGISTER_OP_CPU_KERNEL(
     ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, int>,
     ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext,
-                              paddle::platform::complex64>,
+                              paddle::platform::complex<float>>,
     ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext,
-                              paddle::platform::complex128>);
+                              paddle::platform::complex<double>>);
 REGISTER_OP_CPU_KERNEL(
     elementwise_add_grad,
     ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, float>,
@@ -145,9 +145,9 @@ REGISTER_OP_CPU_KERNEL(
     ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, int>,
     ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext,
-                                  paddle::platform::complex64>,
+                                  paddle::platform::complex<float>>,
     ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext,
-                                  paddle::platform::complex128>);
+                                  paddle::platform::complex<double>>);
 REGISTER_OP_CPU_KERNEL(
     elementwise_add_grad_grad,
     ops::ElementwiseAddDoubleGradKernel<paddle::platform::CPUDeviceContext,
@@ -159,9 +159,9 @@ REGISTER_OP_CPU_KERNEL(
     ops::ElementwiseAddDoubleGradKernel<paddle::platform::CPUDeviceContext,
                                         int64_t>,
     ops::ElementwiseAddDoubleGradKernel<paddle::platform::CPUDeviceContext,
-                                        paddle::platform::complex64>,
+                                        paddle::platform::complex<float>>,
     ops::ElementwiseAddDoubleGradKernel<paddle::platform::CPUDeviceContext,
-                                        paddle::platform::complex128>);
+                                        paddle::platform::complex<double>>);
 // A specialization elementwise_add operator, used in gradient accumulation with
 // inplace addto.
@@ -178,9 +178,9 @@ REGISTER_OP_CPU_KERNEL(
     ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, int>,
     ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext,
-                              paddle::platform::complex64>,
+                              paddle::platform::complex<float>>,
     ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext,
-                              paddle::platform::complex128>);
+                              paddle::platform::complex<double>>);
 REGISTER_OP_VERSION(elementwise_add)
     .AddCheckpoint(
...
@@ -13,8 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
-#include "paddle/fluid/platform/complex128.h"
-#include "paddle/fluid/platform/complex64.h"
+#include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/float16.h"
 namespace ops = paddle::operators;
@@ -141,8 +140,8 @@ REGISTER_OP_CUDA_KERNEL(
     ops::ElementwiseAddKernel<plat::CUDADeviceContext, int>,
     ops::ElementwiseAddKernel<plat::CUDADeviceContext, int64_t>,
     ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::float16>,
-    ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex64>,
-    ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex128>);
+    ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex<float>>,
+    ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex<double>>);
 REGISTER_OP_CUDA_KERNEL(
     elementwise_add_grad,
     ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, float>,
@@ -150,8 +149,10 @@ REGISTER_OP_CUDA_KERNEL(
     ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, int>,
     ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, int64_t>,
     ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, plat::float16>,
-    ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, plat::complex64>,
-    ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, plat::complex128>);
+    ops::ElementwiseAddGradKernel<plat::CUDADeviceContext,
+                                  plat::complex<float>>,
+    ops::ElementwiseAddGradKernel<plat::CUDADeviceContext,
+                                  plat::complex<double>>);
 REGISTER_OP_CUDA_KERNEL(
     elementwise_add_grad_grad,
     ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext, float>,
@@ -160,9 +161,9 @@ REGISTER_OP_CUDA_KERNEL(
     ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext, int64_t>,
     ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext, plat::float16>,
     ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext,
-                                        plat::complex64>,
+                                        plat::complex<float>>,
     ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext,
-                                        plat::complex128>);
+                                        plat::complex<double>>);
 REGISTER_OP_CUDA_KERNEL(
     grad_add, ops::ElementwiseAddKernel<plat::CUDADeviceContext, float>,
@@ -170,5 +171,5 @@ REGISTER_OP_CUDA_KERNEL(
     ops::ElementwiseAddKernel<plat::CUDADeviceContext, int>,
     ops::ElementwiseAddKernel<plat::CUDADeviceContext, int64_t>,
     ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::float16>,
-    ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex64>,
-    ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex128>);
+    ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex<float>>,
+    ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex<double>>);
@@ -17,8 +17,7 @@ limitations under the License. */
 #include <string>
 #include "paddle/fluid/operators/elementwise/elementwise_op.h"
-#include "paddle/fluid/platform/complex128.h"
-#include "paddle/fluid/platform/complex64.h"
+#include "paddle/fluid/platform/complex.h"
 namespace paddle {
 namespace operators {
@@ -135,9 +134,9 @@ REGISTER_OP_CPU_KERNEL(
     ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext, int>,
     ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext,
-                              paddle::platform::complex64>,
+                              paddle::platform::complex<float>>,
     ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext,
-                              paddle::platform::complex128>);
+                              paddle::platform::complex<double>>);
 REGISTER_OP_CPU_KERNEL(
     elementwise_div_grad,
     ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext, float>,
@@ -145,9 +144,9 @@ REGISTER_OP_CPU_KERNEL(
     ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext, int>,
     ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext,
-                                  paddle::platform::complex64>,
+                                  paddle::platform::complex<float>>,
     ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext,
-                                  paddle::platform::complex128>);
+                                  paddle::platform::complex<double>>);
 REGISTER_OP_CPU_KERNEL(
     elementwise_div_grad_grad,
@@ -160,9 +159,9 @@ REGISTER_OP_CPU_KERNEL(
     ops::ElementwiseDivDoubleGradKernel<paddle::platform::CPUDeviceContext,
                                         int64_t>,
     ops::ElementwiseDivDoubleGradKernel<paddle::platform::CPUDeviceContext,
-                                        paddle::platform::complex64>,
+                                        paddle::platform::complex<float>>,
     ops::ElementwiseDivDoubleGradKernel<paddle::platform::CPUDeviceContext,
-                                        paddle::platform::complex128>);
+                                        paddle::platform::complex<double>>);
 REGISTER_OP_VERSION(elementwise_div)
     .AddCheckpoint(
...
@@ -14,8 +14,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/elementwise/elementwise_div_op.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
-#include "paddle/fluid/platform/complex128.h"
-#include "paddle/fluid/platform/complex64.h"
+#include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/float16.h"
 namespace ops = paddle::operators;
@@ -76,18 +75,21 @@ static __global__ void SimpleElemwiseDivGradCUDAKernel(const T* x, const T* y,
 }
 template <>
-__global__ void SimpleElemwiseDivGradCUDAKernel<paddle::platform::complex64>(
-    const paddle::platform::complex64* x, const paddle::platform::complex64* y,
-    const paddle::platform::complex64* out,
-    const paddle::platform::complex64* dout, int64_t size,
-    paddle::platform::complex64* dx, paddle::platform::complex64* dy) {
+__global__ void
+SimpleElemwiseDivGradCUDAKernel<paddle::platform::complex<float>>(
+    const paddle::platform::complex<float>* x,
+    const paddle::platform::complex<float>* y,
+    const paddle::platform::complex<float>* out,
+    const paddle::platform::complex<float>* dout, int64_t size,
+    paddle::platform::complex<float>* dx,
+    paddle::platform::complex<float>* dy) {
   int col = blockIdx.x * blockDim.x + threadIdx.x;
   while (col < size) {
-    paddle::platform::complex64 o = dout[col];
-    paddle::platform::complex64 y_conj(y[col].real, -y[col].imag);
-    paddle::platform::complex64 out_div_y_conj((out[col] / y[col]).real,
-                                               -(out[col] / y[col]).imag);
+    paddle::platform::complex<float> o = dout[col];
+    paddle::platform::complex<float> y_conj(y[col].real, -y[col].imag);
+    paddle::platform::complex<float> out_div_y_conj((out[col] / y[col]).real,
+                                                    -(out[col] / y[col]).imag);
     dx[col] = o / y_conj;
     dy[col] = -o * out_div_y_conj;
     col += blockDim.x * gridDim.x;
@@ -95,19 +97,21 @@ __global__ void SimpleElemwiseDivGradCUDAKernel<paddle::platform::complex64>(
 }
 template <>
-__global__ void SimpleElemwiseDivGradCUDAKernel<paddle::platform::complex128>(
-    const paddle::platform::complex128* x,
-    const paddle::platform::complex128* y,
-    const paddle::platform::complex128* out,
-    const paddle::platform::complex128* dout, int64_t size,
-    paddle::platform::complex128* dx, paddle::platform::complex128* dy) {
+__global__ void
+SimpleElemwiseDivGradCUDAKernel<paddle::platform::complex<double>>(
+    const paddle::platform::complex<double>* x,
+    const paddle::platform::complex<double>* y,
+    const paddle::platform::complex<double>* out,
+    const paddle::platform::complex<double>* dout, int64_t size,
+    paddle::platform::complex<double>* dx,
+    paddle::platform::complex<double>* dy) {
   int col = blockIdx.x * blockDim.x + threadIdx.x;
   while (col < size) {
-    paddle::platform::complex128 o = dout[col];
-    paddle::platform::complex128 y_conj(y[col].real, -y[col].imag);
-    paddle::platform::complex128 out_div_y_conj((out[col] / y[col]).real,
-                                                -(out[col] / y[col]).imag);
+    paddle::platform::complex<double> o = dout[col];
+    paddle::platform::complex<double> y_conj(y[col].real, -y[col].imag);
+    paddle::platform::complex<double> out_div_y_conj((out[col] / y[col]).real,
+                                                     -(out[col] / y[col]).imag);
     dx[col] = o / y_conj;
     dy[col] = -o * out_div_y_conj;
     col += blockDim.x * gridDim.x;
@@ -145,9 +149,9 @@ REGISTER_OP_CUDA_KERNEL(
     ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, int>,
     ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, int64_t>,
     ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext,
-                              paddle::platform::complex64>,
+                              paddle::platform::complex<float>>,
     ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext,
-                              paddle::platform::complex128>);
+                              paddle::platform::complex<double>>);
 REGISTER_OP_CUDA_KERNEL(
     elementwise_div_grad,
     ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, float>,
@@ -157,9 +161,9 @@ REGISTER_OP_CUDA_KERNEL(
     ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, int>,
     ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
     ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext,
-                                  paddle::platform::complex64>,
+                                  paddle::platform::complex<float>>,
     ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext,
-                                  paddle::platform::complex128>);
+                                  paddle::platform::complex<double>>);
 REGISTER_OP_CUDA_KERNEL(
     elementwise_div_grad_grad,
     ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
@@ -173,6 +177,6 @@ REGISTER_OP_CUDA_KERNEL(
     ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                         int64_t>,
     ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
-                                        paddle::platform::complex64>,
+                                        paddle::platform::complex<float>>,
     ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
-                                        paddle::platform::complex128>);
+                                        paddle::platform::complex<double>>);
@@ -74,23 +74,13 @@ struct DivGradDX {
   HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout / y; }
 };
-template <>
-struct DivGradDX<paddle::platform::complex64> {
-  HOSTDEVICE paddle::platform::complex64 operator()(
-      paddle::platform::complex64 x, paddle::platform::complex64 y,
-      paddle::platform::complex64 out, paddle::platform::complex64 dout) const {
-    paddle::platform::complex64 y_conj(y.real, -y.imag);
-    return dout / y_conj;
-  }
-};
-template <>
-struct DivGradDX<paddle::platform::complex128> {
-  HOSTDEVICE paddle::platform::complex128 operator()(
-      paddle::platform::complex128 x, paddle::platform::complex128 y,
-      paddle::platform::complex128 out,
-      paddle::platform::complex128 dout) const {
-    paddle::platform::complex128 y_conj(y.real, -y.imag);
+template <typename T>
+struct DivGradDX<paddle::platform::complex<T>> {
+  HOSTDEVICE paddle::platform::complex<T> operator()(
+      paddle::platform::complex<T> x, paddle::platform::complex<T> y,
+      paddle::platform::complex<T> out,
+      paddle::platform::complex<T> dout) const {
+    paddle::platform::complex<T> y_conj(y.real, -y.imag);
     return dout / y_conj;
   }
 };
@@ -102,23 +92,13 @@ struct DivGradDY {
   }
 };
-template <>
-struct DivGradDY<paddle::platform::complex64> {
-  HOSTDEVICE paddle::platform::complex64 operator()(
-      paddle::platform::complex64 x, paddle::platform::complex64 y,
-      paddle::platform::complex64 out, paddle::platform::complex64 dout) const {
-    paddle::platform::complex64 out_div_y_conj((out / y).real, -(out / y).imag);
-    return -dout * out_div_y_conj;
-  }
-};
-template <>
-struct DivGradDY<paddle::platform::complex128> {
-  HOSTDEVICE paddle::platform::complex128 operator()(
-      paddle::platform::complex128 x, paddle::platform::complex128 y,
-      paddle::platform::complex128 out,
-      paddle::platform::complex128 dout) const {
-    paddle::platform::complex128 out_div_y_conj((out / y).real,
+template <typename T>
+struct DivGradDY<paddle::platform::complex<T>> {
+  HOSTDEVICE paddle::platform::complex<T> operator()(
+      paddle::platform::complex<T> x, paddle::platform::complex<T> y,
+      paddle::platform::complex<T> out,
+      paddle::platform::complex<T> dout) const {
+    paddle::platform::complex<T> out_div_y_conj((out / y).real,
                                                 -(out / y).imag);
     return -dout * out_div_y_conj;
   }
...
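Why the conjugates show up in DivGradDX and DivGradDY: for a real-valued loss L and out = x / y, the complex-gradient convention assumed here conjugates each partial derivative before multiplying by the incoming gradient dout. A short derivation of what the two functors compute (a standard Wirtinger-style result, stated as an assumption about the convention rather than quoted from the sources):

\[
\frac{\partial L}{\partial x}
  = dout \cdot \overline{\frac{\partial\, out}{\partial x}}
  = \frac{dout}{\overline{y}},
\qquad
\frac{\partial L}{\partial y}
  = dout \cdot \overline{\left(-\frac{x}{y^{2}}\right)}
  = -\,dout \cdot \overline{\left(\frac{out}{y}\right)}
\]

which is exactly dout / y_conj in DivGradDX and -dout * out_div_y_conj in DivGradDY; the MulGradDX/MulGradDY specializations later in this diff follow the same rule, with derivatives y and x giving dout * conj(y) and dout * conj(x).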
@@ -16,8 +16,7 @@ limitations under the License. */
 #include <memory>
 #include <string>
 #include "paddle/fluid/operators/elementwise/elementwise_op.h"
-#include "paddle/fluid/platform/complex128.h"
-#include "paddle/fluid/platform/complex64.h"
+#include "paddle/fluid/platform/complex.h"
 namespace paddle {
 namespace operators {
@@ -134,9 +133,9 @@ REGISTER_OP_CPU_KERNEL(
     ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, int>,
     ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext,
-                              paddle::platform::complex64>,
+                              paddle::platform::complex<float>>,
     ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext,
-                              paddle::platform::complex128>);
+                              paddle::platform::complex<double>>);
 REGISTER_OP_CPU_KERNEL(
     elementwise_mul_grad,
     ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, float>,
@@ -144,9 +143,9 @@ REGISTER_OP_CPU_KERNEL(
     ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, int>,
     ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext,
-                                  paddle::platform::complex64>,
+                                  paddle::platform::complex<float>>,
     ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext,
-                                  paddle::platform::complex128>);
+                                  paddle::platform::complex<double>>);
 REGISTER_OP_CPU_KERNEL(
     elementwise_mul_grad_grad,
     ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
@@ -158,9 +157,9 @@ REGISTER_OP_CPU_KERNEL(
     ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
                                         int64_t>,
     ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
-                                        paddle::platform::complex64>,
+                                        paddle::platform::complex<float>>,
     ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
-                                        paddle::platform::complex128>);
+                                        paddle::platform::complex<double>>);
 REGISTER_OP_VERSION(elementwise_mul)
     .AddCheckpoint(
...
@@ -14,8 +14,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
-#include "paddle/fluid/platform/complex128.h"
-#include "paddle/fluid/platform/complex64.h"
+#include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/float16.h"
 namespace ops = paddle::operators;
@@ -76,31 +75,31 @@ static __global__ void SimpleElemwiseMulGradCUDAKernel(const T* x, const T* y,
 }
 template <>
-__global__ void SimpleElemwiseMulGradCUDAKernel<plat::complex64>(
-    const plat::complex64* x, const plat::complex64* y,
-    const plat::complex64* out, const plat::complex64* dout, int64_t size,
-    plat::complex64* dx, plat::complex64* dy) {
+__global__ void SimpleElemwiseMulGradCUDAKernel<plat::complex<float>>(
+    const plat::complex<float>* x, const plat::complex<float>* y,
+    const plat::complex<float>* out, const plat::complex<float>* dout,
+    int64_t size, plat::complex<float>* dx, plat::complex<float>* dy) {
   int col = blockIdx.x * blockDim.x + threadIdx.x;
   while (col < size) {
-    plat::complex64 o = dout[col];
-    dx[col] = plat::complex64(y[col].real, -y[col].imag) * o;
-    dy[col] = plat::complex64(x[col].real, -x[col].imag) * o;
+    plat::complex<float> o = dout[col];
+    dx[col] = plat::complex<float>(y[col].real, -y[col].imag) * o;
+    dy[col] = plat::complex<float>(x[col].real, -x[col].imag) * o;
     col += blockDim.x * gridDim.x;
   }
 }
 template <>
-__global__ void SimpleElemwiseMulGradCUDAKernel<plat::complex128>(
-    const plat::complex128* x, const plat::complex128* y,
-    const plat::complex128* out, const plat::complex128* dout, int64_t size,
-    plat::complex128* dx, plat::complex128* dy) {
+__global__ void SimpleElemwiseMulGradCUDAKernel<plat::complex<double>>(
+    const plat::complex<double>* x, const plat::complex<double>* y,
+    const plat::complex<double>* out, const plat::complex<double>* dout,
+    int64_t size, plat::complex<double>* dx, plat::complex<double>* dy) {
   int col = blockIdx.x * blockDim.x + threadIdx.x;
   while (col < size) {
-    plat::complex128 o = dout[col];
-    dx[col] = plat::complex128(y[col].real, -y[col].imag) * o;
-    dy[col] = plat::complex128(x[col].real, -x[col].imag) * o;
+    plat::complex<double> o = dout[col];
+    dx[col] = plat::complex<double>(y[col].real, -y[col].imag) * o;
+    dy[col] = plat::complex<double>(x[col].real, -x[col].imag) * o;
     col += blockDim.x * gridDim.x;
   }
 }
@@ -133,8 +132,8 @@ REGISTER_OP_CUDA_KERNEL(
     ops::ElementwiseMulKernel<plat::CUDADeviceContext, int>,
     ops::ElementwiseMulKernel<plat::CUDADeviceContext, int64_t>,
     ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::float16>,
-    ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::complex64>,
-    ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::complex128>);
+    ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::complex<float>>,
+    ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::complex<double>>);
 REGISTER_OP_CUDA_KERNEL(
     elementwise_mul_grad,
     ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, float>,
@@ -142,8 +141,10 @@ REGISTER_OP_CUDA_KERNEL(
     ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, int>,
     ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, int64_t>,
     ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, plat::float16>,
-    ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, plat::complex64>,
-    ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, plat::complex128>);
+    ops::ElementwiseMulGradKernel<plat::CUDADeviceContext,
+                                  plat::complex<float>>,
+    ops::ElementwiseMulGradKernel<plat::CUDADeviceContext,
+                                  plat::complex<double>>);
 REGISTER_OP_CUDA_KERNEL(
     elementwise_mul_grad_grad,
     ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, float>,
@@ -152,6 +153,6 @@ REGISTER_OP_CUDA_KERNEL(
     ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, int64_t>,
     ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, plat::float16>,
     ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext,
-                                        plat::complex64>,
+                                        plat::complex<float>>,
     ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext,
-                                        plat::complex128>);
+                                        plat::complex<double>>);
@@ -132,23 +132,13 @@ struct MulGradDX {
   HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * y; }
 };
-template <>
-struct MulGradDX<paddle::platform::complex64> {
-  HOSTDEVICE paddle::platform::complex64 operator()(
-      paddle::platform::complex64 x, paddle::platform::complex64 y,
-      paddle::platform::complex64 out, paddle::platform::complex64 dout) const {
-    paddle::platform::complex64 y_conj(y.real, -y.imag);
-    return dout * y_conj;
-  }
-};
-template <>
-struct MulGradDX<paddle::platform::complex128> {
-  HOSTDEVICE paddle::platform::complex128 operator()(
-      paddle::platform::complex128 x, paddle::platform::complex128 y,
-      paddle::platform::complex128 out,
-      paddle::platform::complex128 dout) const {
-    paddle::platform::complex128 y_conj(y.real, -y.imag);
+template <typename T>
+struct MulGradDX<paddle::platform::complex<T>> {
+  HOSTDEVICE paddle::platform::complex<T> operator()(
+      paddle::platform::complex<T> x, paddle::platform::complex<T> y,
+      paddle::platform::complex<T> out,
+      paddle::platform::complex<T> dout) const {
+    paddle::platform::complex<T> y_conj(y.real, -y.imag);
     return dout * y_conj;
   }
 };
@@ -158,23 +148,13 @@ struct MulGradDY {
   HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * x; }
 };
-template <>
-struct MulGradDY<paddle::platform::complex64> {
-  HOSTDEVICE paddle::platform::complex64 operator()(
-      paddle::platform::complex64 x, paddle::platform::complex64 y,
-      paddle::platform::complex64 out, paddle::platform::complex64 dout) const {
-    paddle::platform::complex64 x_conj(x.real, -x.imag);
-    return dout * x_conj;
-  }
-};
-template <>
-struct MulGradDY<paddle::platform::complex128> {
-  HOSTDEVICE paddle::platform::complex128 operator()(
-      paddle::platform::complex128 x, paddle::platform::complex128 y,
-      paddle::platform::complex128 out,
-      paddle::platform::complex128 dout) const {
-    paddle::platform::complex128 x_conj(x.real, -x.imag);
+template <typename T>
+struct MulGradDY<paddle::platform::complex<T>> {
+  HOSTDEVICE paddle::platform::complex<T> operator()(
+      paddle::platform::complex<T> x, paddle::platform::complex<T> y,
+      paddle::platform::complex<T> out,
+      paddle::platform::complex<T> dout) const {
+    paddle::platform::complex<T> x_conj(x.real, -x.imag);
     return dout * x_conj;
   }
 };
...
@@ -20,8 +20,8 @@ limitations under the License. */
 namespace paddle {
 namespace platform {
-struct complex128;
-struct complex64;
+template <typename T>
+struct complex;
 }  // namespace platform
 }  // namespace paddle
@@ -134,9 +134,9 @@ REGISTER_OP_CPU_KERNEL(
     ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, int>,
     ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext,
-                              paddle::platform::complex64>,
+                              paddle::platform::complex<float>>,
     ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext,
-                              paddle::platform::complex128>);
+                              paddle::platform::complex<double>>);
 REGISTER_OP_CPU_KERNEL(
     elementwise_sub_grad,
     ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext, float>,
@@ -144,9 +144,9 @@ REGISTER_OP_CPU_KERNEL(
     ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext, int>,
     ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext,
-                                  paddle::platform::complex64>,
+                                  paddle::platform::complex<float>>,
     ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext,
-                                  paddle::platform::complex128>);
+                                  paddle::platform::complex<double>>);
 REGISTER_OP_CPU_KERNEL(
     elementwise_sub_grad_grad,
     ops::ElementwiseSubDoubleGradKernel<paddle::platform::CPUDeviceContext,
@@ -158,9 +158,9 @@ REGISTER_OP_CPU_KERNEL(
     ops::ElementwiseSubDoubleGradKernel<paddle::platform::CPUDeviceContext,
                                         int64_t>,
     ops::ElementwiseSubDoubleGradKernel<paddle::platform::CPUDeviceContext,
-                                        paddle::platform::complex64>,
+                                        paddle::platform::complex<float>>,
     ops::ElementwiseSubDoubleGradKernel<paddle::platform::CPUDeviceContext,
-                                        paddle::platform::complex128>);
+                                        paddle::platform::complex<double>>);
 REGISTER_OP_VERSION(elementwise_sub)
     .AddCheckpoint(
...
@@ -14,8 +14,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
 #include "paddle/fluid/operators/elementwise/elementwise_sub_op.h"
-#include "paddle/fluid/platform/complex128.h"
-#include "paddle/fluid/platform/complex64.h"
+#include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/float16.h"
 namespace ops = paddle::operators;
@@ -103,9 +102,9 @@ REGISTER_OP_CUDA_KERNEL(
     ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext, int>,
     ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext, int64_t>,
     ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext,
-                              paddle::platform::complex64>,
+                              paddle::platform::complex<float>>,
     ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext,
-                              paddle::platform::complex128>);
+                              paddle::platform::complex<double>>);
 REGISTER_OP_CUDA_KERNEL(
     elementwise_sub_grad,
     ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext, float>,
@@ -115,9 +114,9 @@ REGISTER_OP_CUDA_KERNEL(
     ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext, int>,
     ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
     ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext,
-                                  paddle::platform::complex64>,
+                                  paddle::platform::complex<float>>,
     ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext,
-                                  paddle::platform::complex128>);
+                                  paddle::platform::complex<double>>);
 REGISTER_OP_CUDA_KERNEL(
     elementwise_sub_grad_grad,
     ops::ElementwiseSubDoubleGradKernel<paddle::platform::CUDADeviceContext,
@@ -129,6 +128,6 @@ REGISTER_OP_CUDA_KERNEL(
     ops::ElementwiseSubDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                         int64_t>,
     ops::ElementwiseSubDoubleGradKernel<paddle::platform::CUDADeviceContext,
-                                        paddle::platform::complex64>,
+                                        paddle::platform::complex<float>>,
     ops::ElementwiseSubDoubleGradKernel<paddle::platform::CUDADeviceContext,
-                                        paddle::platform::complex128>);
+                                        paddle::platform::complex<double>>);
@@ -16,8 +16,7 @@ limitations under the License. */
 // NOTE(): support float16 to half in header file.
 #define PADDLE_CUDA_FP16
-#include "paddle/fluid/platform/complex128.h"
-#include "paddle/fluid/platform/complex64.h"
+#include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/float16.h"
 namespace paddle {
@@ -82,28 +81,52 @@ __forceinline__ __device__ T CudaShuffleXorSync(unsigned mask, T val,
 #endif
 }
-// CUDA 9.0 have native compatible float16 shfl_down
 #if defined(PADDLE_WITH_HIP)
 template <>
 __forceinline__ __device__ float16 CudaShuffleDownSync(unsigned mask,
                                                        float16 val, int delta,
                                                        int width) {
-#ifdef PADDLE_WITH_HIP
   return float16(__shfl_down(static_cast<float>(val),
                              static_cast<unsigned>(delta), width));
-#else
-  return float16(
-      __shfl_down(static_cast<half>(val), static_cast<unsigned>(delta), width));
-#endif
 }
+template <>
+__forceinline__ __device__ paddle::platform::complex<float> CudaShuffleDownSync(
+    unsigned mask, paddle::platform::complex<float> val, int delta, int width) {
+  float real = __shfl_down(val.real, delta, width);
+  float imag = __shfl_down(val.imag, delta, width);
+  return paddle::platform::complex<float>(real, imag);
+}
+template <>
+__forceinline__ __device__ paddle::platform::complex<double>
+CudaShuffleDownSync(unsigned mask, paddle::platform::complex<double> val,
+                    int delta, int width) {
+  double real = __shfl_down(val.real, delta, width);
+  double imag = __shfl_down(val.imag, delta, width);
+  return paddle::platform::complex<double>(real, imag);
+}
 template <>
 __forceinline__ __device__ float16 CudaShuffleXorSync(unsigned mask,
                                                       float16 val, int width) {
-#ifdef PADDLE_WITH_HIP
   return float16(__shfl_xor(static_cast<float>(val), width));
-#else
-  return float16(__shfl_xor(static_cast<half>(val), width));
-#endif
 }
+template <>
+__forceinline__ __device__ paddle::platform::complex<float> CudaShuffleXorSync(
+    unsigned mask, paddle::platform::complex<float> val, int width) {
+  float real = __shfl_xor(val.real, width);
+  float imag = __shfl_xor(val.imag, width);
+  return paddle::platform::complex<float>(real, imag);
+}
+template <>
+__forceinline__ __device__ paddle::platform::complex<double> CudaShuffleXorSync(
+    unsigned mask, paddle::platform::complex<double> val, int width) {
+  double real = __shfl_xor(val.real, width);
+  double imag = __shfl_xor(val.imag, width);
+  return paddle::platform::complex<double>(real, imag);
+}
 #else
 template <>
@@ -115,25 +138,26 @@ __forceinline__ __device__ float16 CudaShuffleDownSync(unsigned mask,
 }
 template <>
-__forceinline__ __device__ paddle::platform::complex64 CudaShuffleDownSync(
-    unsigned mask, paddle::platform::complex64 val, int delta, int width) {
+__forceinline__ __device__ paddle::platform::complex<float> CudaShuffleDownSync(
+    unsigned mask, paddle::platform::complex<float> val, int delta, int width) {
   float real = static_cast<float>(__shfl_down_sync(
       mask, static_cast<float>(val.real), static_cast<unsigned>(delta), width));
   float imag = static_cast<float>(__shfl_down_sync(
      mask, static_cast<float>(val.imag), static_cast<unsigned>(delta), width));
-  return paddle::platform::complex64(real, imag);
+  return paddle::platform::complex<float>(real, imag);
 }
 template <>
-__forceinline__ __device__ paddle::platform::complex128 CudaShuffleDownSync(
-    unsigned mask, paddle::platform::complex128 val, int delta, int width) {
+__forceinline__ __device__ paddle::platform::complex<double>
+CudaShuffleDownSync(unsigned mask, paddle::platform::complex<double> val,
+                    int delta, int width) {
   double real = static_cast<double>(
       __shfl_down_sync(mask, static_cast<double>(val.real),
                        static_cast<unsigned>(delta), width));
   double imag = static_cast<double>(
      __shfl_down_sync(mask, static_cast<double>(val.imag),
                       static_cast<unsigned>(delta), width));
-  return paddle::platform::complex128(real, imag);
+  return paddle::platform::complex<double>(real, imag);
 }
 template <>
@@ -143,23 +167,23 @@ __forceinline__ __device__ float16 CudaShuffleXorSync(unsigned mask,
 }
 template <>
-__forceinline__ __device__ paddle::platform::complex64 CudaShuffleXorSync(
-    unsigned mask, paddle::platform::complex64 val, int width) {
+__forceinline__ __device__ paddle::platform::complex<float> CudaShuffleXorSync(
+    unsigned mask, paddle::platform::complex<float> val, int width) {
  float real = static_cast<float>(
      __shfl_xor_sync(mask, static_cast<float>(val.real), width));
  float imag = static_cast<float>(
      __shfl_xor_sync(mask, static_cast<float>(val.imag), width));
-  return paddle::platform::complex64(real, imag);
+  return paddle::platform::complex<float>(real, imag);
 }
 template <>
-__forceinline__ __device__ paddle::platform::complex128 CudaShuffleXorSync(
-    unsigned mask, paddle::platform::complex128 val, int width) {
+__forceinline__ __device__ paddle::platform::complex<double> CudaShuffleXorSync(
+    unsigned mask, paddle::platform::complex<double> val, int width) {
  double real = static_cast<double>(
      __shfl_xor_sync(mask, static_cast<double>(val.real), width));
  double imag = static_cast<double>(
      __shfl_xor_sync(mask, static_cast<double>(val.imag), width));
-  return paddle::platform::complex128(real, imag);
+  return paddle::platform::complex<double>(real, imag);
 }
 #endif
...
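These shuffle specializations exist because warp intrinsics such as __shfl_down_sync move scalar registers, so a two-component complex value has to be exchanged as two independent component shuffles, exactly as the code above does. A hedged CUDA sketch of a warp-level complex sum built the same way (Complexf, ShflDown, and WarpSum are assumed toy names, not Paddle code):

// warp_complex_sum.cu -- illustrative only.
#include <cstdio>
#include <cuda_runtime.h>

struct Complexf {
  float real, imag;
};

// Shuffle the two scalar components separately, mirroring the
// CudaShuffleDownSync specializations in the diff above.
__device__ Complexf ShflDown(unsigned mask, Complexf v, int delta, int width) {
  Complexf r;
  r.real = __shfl_down_sync(mask, v.real, delta, width);
  r.imag = __shfl_down_sync(mask, v.imag, delta, width);
  return r;
}

__global__ void WarpSum(const Complexf* in, Complexf* out) {
  Complexf v = in[threadIdx.x];
  // Tree reduction across one 32-lane warp; lane 0 ends with the total.
  for (int delta = 16; delta > 0; delta >>= 1) {
    Complexf o = ShflDown(0xffffffffu, v, delta, 32);
    v.real += o.real;
    v.imag += o.imag;
  }
  if (threadIdx.x == 0) *out = v;
}

int main() {
  Complexf h_in[32], h_out, *d_in, *d_out;
  for (int i = 0; i < 32; ++i) h_in[i] = {1.0f, 2.0f};
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(Complexf));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  WarpSum<<<1, 32>>>(d_in, d_out);
  cudaMemcpy(&h_out, d_out, sizeof(Complexf), cudaMemcpyDeviceToHost);
  printf("%f %f\n", h_out.real, h_out.imag);  // expect 32.000000 64.000000
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}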