Unverified commit 6c399d94, authored by chentianyu03, committed by GitHub

Modify Ops from complex64/128 to complex<float/double> types. (#33133)

* modify kron OP to complex template types

* modify reshape, slice, trace, transpose OPs to complex template types

* modify to complex template types in eigen slice files

* change to complex template types for pad.cc and pad.cu

* format code style
Parent 6a5b7e59
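Context for the diff below: the commit replaces the two concrete classes paddle::platform::complex64 and paddle::platform::complex128 with the single class template paddle::platform::complex<T>. As a minimal sketch of why one template can serve both precisions — SimplifiedComplex here is a hypothetical stand-in, not Paddle's actual paddle/fluid/platform/complex.h implementation:

```cpp
#include <iostream>

// Hypothetical stand-in for paddle::platform::complex<T>; illustrative only.
template <typename T>
struct SimplifiedComplex {
  T real;
  T imag;
};

// One templated multiply covers both precisions, where the old design
// needed separate hand-written complex64 and complex128 classes.
template <typename T>
SimplifiedComplex<T> operator*(const SimplifiedComplex<T>& a,
                               const SimplifiedComplex<T>& b) {
  return {a.real * b.real - a.imag * b.imag,
          a.real * b.imag + a.imag * b.real};
}

int main() {
  SimplifiedComplex<float> cf{1.0f, 2.0f};   // plays the role of complex64
  SimplifiedComplex<double> cd{1.0, 2.0};    // plays the role of complex128
  auto pf = cf * cf;
  auto pd = cd * cd;
  std::cout << pf.real << "+" << pf.imag << "i\n"
            << pd.real << "+" << pd.imag << "i\n";
}
```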
@@ -12,8 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/operators/eigen/eigen_function.h"
-#include "paddle/fluid/platform/complex128.h"
-#include "paddle/fluid/platform/complex64.h"
+#include "paddle/fluid/platform/complex.h"
 namespace paddle {
 namespace operators {
@@ -56,8 +55,8 @@ INSTANTIATION(EigenPad, int);
 INSTANTIATION(EigenPad, int64_t);
 INSTANTIATION(EigenPad, float);
 INSTANTIATION(EigenPad, double);
-INSTANTIATION(EigenPad, platform::complex64);
-INSTANTIATION(EigenPad, platform::complex128);
+INSTANTIATION(EigenPad, platform::complex<float>);
+INSTANTIATION(EigenPad, platform::complex<double>);
 #undef INSTANTIATION
 }  // namespace operators
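The INSTANTIATION lines above explicitly instantiate the Eigen pad functor once per scalar type, so adding a dtype is a one-line change. The macro body, functor, and rank list below are assumptions for illustration, not Paddle's actual definitions — a rough sketch of the explicit-instantiation pattern:

```cpp
#include <cstddef>

// Illustrative functor; Paddle's EigenPad/EigenSlice live in eigen_function.h.
template <typename T, int Rank>
struct PadFunctorSketch {
  static void Eval(const T* in, T* out, std::size_t n) {
    for (std::size_t i = 0; i < n; ++i) out[i] = in[i];  // stand-in body
  }
};

// One macro call per scalar type forces the compiler to emit code for the
// functor at a fixed set of ranks in this translation unit.
#define INSTANTIATION(FUNCTOR, TYPE) \
  template struct FUNCTOR<TYPE, 1>;  \
  template struct FUNCTOR<TYPE, 2>;  \
  template struct FUNCTOR<TYPE, 3>

INSTANTIATION(PadFunctorSketch, float);
INSTANTIATION(PadFunctorSketch, double);
#undef INSTANTIATION
```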
@@ -12,8 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/operators/eigen/eigen_function.h"
-#include "paddle/fluid/platform/complex128.h"
-#include "paddle/fluid/platform/complex64.h"
+#include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/float16.h"
 namespace paddle {
@@ -58,8 +57,8 @@ INSTANTIATION(EigenPad, int64_t);
 INSTANTIATION(EigenPad, float);
 INSTANTIATION(EigenPad, double);
 INSTANTIATION(EigenPad, platform::float16);
-INSTANTIATION(EigenPad, platform::complex64);
-INSTANTIATION(EigenPad, platform::complex128);
+INSTANTIATION(EigenPad, platform::complex<float>);
+INSTANTIATION(EigenPad, platform::complex<double>);
 #undef INSTANTIATION
 }  // namespace operators
@@ -14,8 +14,6 @@ limitations under the License. */
 #include "paddle/fluid/operators/eigen/eigen_function.h"
 #include "paddle/fluid/platform/bfloat16.h"
 #include "paddle/fluid/platform/complex.h"
-#include "paddle/fluid/platform/complex128.h"
-#include "paddle/fluid/platform/complex64.h"
 #include "paddle/fluid/platform/float16.h"
 namespace paddle {
@@ -69,8 +67,6 @@ INSTANTIATION(EigenSlice, float);
 INSTANTIATION(EigenSlice, double);
 INSTANTIATION(EigenSlice, platform::float16);
 INSTANTIATION(EigenSlice, platform::bfloat16);
-INSTANTIATION(EigenSlice, platform::complex64);
-INSTANTIATION(EigenSlice, platform::complex128);
 INSTANTIATION(EigenSlice, platform::complex<float>);
 INSTANTIATION(EigenSlice, platform::complex<double>);
 #undef INSTANTIATION
@@ -12,8 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/operators/eigen/eigen_function.h"
-#include "paddle/fluid/platform/complex128.h"
-#include "paddle/fluid/platform/complex64.h"
+#include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/float16.h"
 namespace paddle {
@@ -58,8 +57,8 @@ INSTANTIATION(EigenSlice, int64_t);
 INSTANTIATION(EigenSlice, float);
 INSTANTIATION(EigenSlice, double);
 INSTANTIATION(EigenSlice, platform::float16);
-INSTANTIATION(EigenSlice, platform::complex64);
-INSTANTIATION(EigenSlice, platform::complex128);
+INSTANTIATION(EigenSlice, platform::complex<float>);
+INSTANTIATION(EigenSlice, platform::complex<double>);
 #undef INSTANTIATION
 }  // namespace operators
@@ -18,8 +18,7 @@ limitations under the License. */
 #include <vector>
 #include "paddle/fluid/operators/kron_op.h"
-#include "paddle/fluid/platform/complex128.h"
-#include "paddle/fluid/platform/complex64.h"
+#include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/float16.h"
 namespace paddle {
@@ -185,9 +184,9 @@ REGISTER_OP_CPU_KERNEL(
     ops::KronKernel<paddle::platform::CPUDeviceContext, int>,
     ops::KronKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::KronKernel<paddle::platform::CPUDeviceContext,
-                    paddle::platform::complex64>,
+                    paddle::platform::complex<float>>,
     ops::KronKernel<paddle::platform::CPUDeviceContext,
-                    paddle::platform::complex128>);
+                    paddle::platform::complex<double>>);
 REGISTER_OPERATOR(kron_grad, ops::KronGradOp);
 REGISTER_OP_CPU_KERNEL(
@@ -198,6 +197,6 @@ REGISTER_OP_CPU_KERNEL(
     ops::KronGradKernel<paddle::platform::CPUDeviceContext, int>,
     ops::KronGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::KronGradKernel<paddle::platform::CPUDeviceContext,
-                        paddle::platform::complex64>,
+                        paddle::platform::complex<float>>,
     ops::KronGradKernel<paddle::platform::CPUDeviceContext,
-                        paddle::platform::complex128>);
+                        paddle::platform::complex<double>>);
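The registration blocks above bind one kernel instantiation per data type to an op name; after this commit the complex entries are the same templated kernel instantiated at complex<float> and complex<double>. A much-simplified, hypothetical sketch of that per-dtype registration pattern (Paddle's real REGISTER_OP_CPU_KERNEL machinery is far richer; all names below are illustrative):

```cpp
#include <functional>
#include <iostream>
#include <map>
#include <string>

// Type-erased kernel table keyed by dtype name; illustrative only.
using Kernel = std::function<void()>;
std::map<std::string, Kernel> g_kron_kernels;

template <typename T>
void RegisterKronKernel(const std::string& dtype) {
  g_kron_kernels[dtype] = [] {
    // A real kernel would run the T-typed Kron compute here.
  };
}

int main() {
  RegisterKronKernel<float>("float32");
  RegisterKronKernel<double>("float64");
  // Complex dtypes reuse the same template at a different instantiation,
  // instead of requiring dedicated complex64 / complex128 kernel classes.
  std::cout << g_kron_kernels.size() << " kernels registered\n";
}
```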
@@ -13,8 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/operators/kron_op.h"
-#include "paddle/fluid/platform/complex128.h"
-#include "paddle/fluid/platform/complex64.h"
+#include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/float16.h"
 namespace ops = paddle::operators;
@@ -26,9 +25,9 @@ REGISTER_OP_CUDA_KERNEL(
     ops::KronKernel<paddle::platform::CUDADeviceContext, int>,
     ops::KronKernel<paddle::platform::CUDADeviceContext, int64_t>,
     ops::KronKernel<paddle::platform::CUDADeviceContext,
-                    paddle::platform::complex64>,
+                    paddle::platform::complex<float>>,
     ops::KronKernel<paddle::platform::CUDADeviceContext,
-                    paddle::platform::complex128>);
+                    paddle::platform::complex<double>>);
 REGISTER_OP_CUDA_KERNEL(
     kron_grad, ops::KronGradKernel<paddle::platform::CUDADeviceContext, float>,
@@ -38,6 +37,6 @@ REGISTER_OP_CUDA_KERNEL(
     ops::KronGradKernel<paddle::platform::CUDADeviceContext, int>,
     ops::KronGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
     ops::KronGradKernel<paddle::platform::CUDADeviceContext,
-                        paddle::platform::complex64>,
+                        paddle::platform::complex<float>>,
     ops::KronGradKernel<paddle::platform::CUDADeviceContext,
-                        paddle::platform::complex128>);
+                        paddle::platform::complex<double>>);
@@ -26,9 +26,6 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
-using complex64 = paddle::platform::complex64;
-using complex128 = paddle::platform::complex128;
-
 // Process an element in the output, used with a parallel-for
 template <typename T>
 struct KronElemFunctor {
@@ -175,72 +172,13 @@ struct KronGradElemFunctor {
   const int ndims_;
 };
-template <>
-struct KronGradElemFunctor<complex64> {
-  KronGradElemFunctor(const complex64* dout, const complex64* A,
-                      const complex64* B, complex64* dout_a, complex64* dout_b,
-                      const int64_t* stride_dout, const int64_t* stride_a,
-                      const int64_t* stride_b, const int64_t* shape_b,
-                      const int64_t numel_a, const int64_t numel_b,
-                      const int ndims)
-      : dout_(dout),
-        A_(A),
-        B_(B),
-        dout_a_(dout_a),
-        dout_b_(dout_b),
-        stride_dout_(stride_dout),
-        stride_a_(stride_a),
-        stride_b_(stride_b),
-        shape_b_(shape_b),
-        numel_a_(numel_a),
-        numel_b_(numel_b),
-        ndims_(ndims) {}
-
-  HOSTDEVICE void operator()(int64_t idx) {
-    int64_t index = idx;
-    int64_t index_a = 0;
-    int64_t index_b = 0;
-    for (int i = 0; i < ndims_; i++) {
-      auto pos_i = index / stride_dout_[i];
-      index = index % stride_dout_[i];
-      auto pos_ai = pos_i / shape_b_[i];
-      auto pos_bi = pos_i % shape_b_[i];
-      index_a += stride_a_[i] * pos_ai;
-      index_b += stride_b_[i] * pos_bi;
-    }
-    if (dout_a_) {
-      size_t index_out_a = index_a * numel_b_ + index_b;
-      dout_a_[index_out_a] =
-          dout_[idx] * complex64(B_[index_b].real, -B_[index_b].imag);
-    }
-    if (dout_b_) {
-      size_t index_out_b = index_b * numel_a_ + index_a;
-      dout_b_[index_out_b] =
-          dout_[idx] * complex64(A_[index_a].real, -A_[index_a].imag);
-    }
-  }
-
- private:
-  const complex64* dout_;
-  const complex64* A_;
-  const complex64* B_;
-  complex64* dout_a_;
-  complex64* dout_b_;
-  const int64_t* stride_dout_;
-  const int64_t* stride_a_;
-  const int64_t* stride_b_;
-  const int64_t* shape_b_;
-  const int64_t numel_a_;
-  const int64_t numel_b_;
-  const int ndims_;
-};
-
-template <>
-struct KronGradElemFunctor<complex128> {
-  KronGradElemFunctor(const complex128* dout, const complex128* A,
-                      const complex128* B, complex128* dout_a,
-                      complex128* dout_b, const int64_t* stride_dout,
+template <typename T>
+struct KronGradElemFunctor<platform::complex<T>> {
+  KronGradElemFunctor(const platform::complex<T>* dout,
+                      const platform::complex<T>* A,
+                      const platform::complex<T>* B,
+                      platform::complex<T>* dout_a,
+                      platform::complex<T>* dout_b, const int64_t* stride_dout,
                       const int64_t* stride_a, const int64_t* stride_b,
                       const int64_t* shape_b, const int64_t numel_a,
                       const int64_t numel_b, const int ndims)
@@ -273,21 +211,23 @@ struct KronGradElemFunctor<complex128> {
     if (dout_a_) {
       size_t index_out_a = index_a * numel_b_ + index_b;
       dout_a_[index_out_a] =
-          dout_[idx] * complex128(B_[index_b].real, -B_[index_b].imag);
+          dout_[idx] *
+          platform::complex<T>(B_[index_b].real, -B_[index_b].imag);
     }
     if (dout_b_) {
       size_t index_out_b = index_b * numel_a_ + index_a;
       dout_b_[index_out_b] =
-          dout_[idx] * complex128(A_[index_a].real, -A_[index_a].imag);
+          dout_[idx] *
+          platform::complex<T>(A_[index_a].real, -A_[index_a].imag);
     }
   }

  private:
-  const complex128* dout_;
-  const complex128* A_;
-  const complex128* B_;
-  complex128* dout_a_;
-  complex128* dout_b_;
+  const platform::complex<T>* dout_;
+  const platform::complex<T>* A_;
+  const platform::complex<T>* B_;
+  platform::complex<T>* dout_a_;
+  platform::complex<T>* dout_b_;
   const int64_t* stride_dout_;
   const int64_t* stride_a_;
   const int64_t* stride_b_;
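The kron_op.h hunk above is the heart of the cleanup: two near-identical full specializations (one for complex64, one for complex128) collapse into a single partial specialization over platform::complex<T>, which multiplies the upstream gradient by the conjugate of the other operand. A minimal, self-contained sketch of the same pattern — Cplx and GradFunctor are illustrative names, not Paddle's:

```cpp
#include <iostream>

template <typename T>
struct Cplx {  // illustrative stand-in for paddle::platform::complex<T>
  T real, imag;
};

template <typename T>
struct GradFunctor {  // primary template: the real-valued path
  T scale(T dout, T other) const { return dout * other; }
};

// One partial specialization covers float and double, replacing two
// copy-pasted full specializations.
template <typename T>
struct GradFunctor<Cplx<T>> {
  Cplx<T> scale(Cplx<T> dout, Cplx<T> other) const {
    // Multiply by the conjugate, as the kron gradient does for complex types.
    Cplx<T> conj{other.real, -other.imag};
    return {dout.real * conj.real - dout.imag * conj.imag,
            dout.real * conj.imag + dout.imag * conj.real};
  }
};

int main() {
  GradFunctor<Cplx<float>> f;   // previously: a dedicated complex64 functor
  GradFunctor<Cplx<double>> d;  // previously: a dedicated complex128 functor
  auto rf = f.scale({1.f, 2.f}, {3.f, 4.f});
  auto rd = d.scale({1.0, 2.0}, {3.0, 4.0});
  std::cout << rf.real << " " << rf.imag << "\n"
            << rd.real << " " << rd.imag << "\n";
}
```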
@@ -613,23 +613,24 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(
     reshape2, float, ops::ReshapeKernel, double, ops::ReshapeKernel, int8_t,
     ops::ReshapeKernel, uint8_t, ops::ReshapeKernel, int, ops::ReshapeKernel,
     int64_t, ops::ReshapeKernel, bool, ops::ReshapeKernel,
-    paddle::platform::bfloat16, ops::ReshapeKernel, paddle::platform::complex64,
-    ops::ReshapeKernel, paddle::platform::complex128, ops::ReshapeKernel);
+    paddle::platform::bfloat16, ops::ReshapeKernel,
+    paddle::platform::complex<float>, ops::ReshapeKernel,
+    paddle::platform::complex<double>, ops::ReshapeKernel);
 REGISTER_OP_CPU_KERNEL_FUNCTOR(
     reshape2_grad, float, ops::ReshapeGradKernel, double,
     ops::ReshapeGradKernel, int, ops::ReshapeGradKernel, uint8_t,
     ops::ReshapeGradKernel, int64_t, ops::ReshapeGradKernel, bool,
     ops::ReshapeGradKernel, paddle::platform::bfloat16, ops::ReshapeGradKernel,
-    paddle::platform::complex64, ops::ReshapeGradKernel,
-    paddle::platform::complex128, ops::ReshapeGradKernel);
+    paddle::platform::complex<float>, ops::ReshapeGradKernel,
+    paddle::platform::complex<double>, ops::ReshapeGradKernel);
 REGISTER_OP_CPU_KERNEL_FUNCTOR(
     reshape2_grad_grad, float, ops::ReshapeDoubleGradKernel, double,
     ops::ReshapeDoubleGradKernel, int, ops::ReshapeDoubleGradKernel, uint8_t,
     ops::ReshapeDoubleGradKernel, int64_t, ops::ReshapeDoubleGradKernel, bool,
     ops::ReshapeDoubleGradKernel, paddle::platform::bfloat16,
-    ops::ReshapeDoubleGradKernel, paddle::platform::complex64,
-    ops::ReshapeDoubleGradKernel, paddle::platform::complex128,
+    ops::ReshapeDoubleGradKernel, paddle::platform::complex<float>,
+    ops::ReshapeDoubleGradKernel, paddle::platform::complex<double>,
     ops::ReshapeDoubleGradKernel);
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
@@ -650,22 +651,23 @@ REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
                                 uint8_t, ops::ReshapeKernel, int64_t,
                                 ops::ReshapeKernel, plat::float16,
                                 ops::ReshapeKernel, bool, ops::ReshapeKernel,
-                                plat::complex64, ops::ReshapeKernel,
-                                plat::complex128, ops::ReshapeKernel);
+                                plat::complex<float>, ops::ReshapeKernel,
+                                plat::complex<double>, ops::ReshapeKernel);
 REGISTER_OP_CUDA_KERNEL_FUNCTOR(
     reshape2_grad, float, ops::ReshapeGradKernel, double,
     ops::ReshapeGradKernel, int, ops::ReshapeGradKernel, uint8_t,
     ops::ReshapeGradKernel, int64_t, ops::ReshapeGradKernel, plat::float16,
-    ops::ReshapeGradKernel, bool, ops::ReshapeGradKernel, plat::complex64,
-    ops::ReshapeGradKernel, plat::complex128, ops::ReshapeGradKernel);
+    ops::ReshapeGradKernel, bool, ops::ReshapeGradKernel, plat::complex<float>,
+    ops::ReshapeGradKernel, plat::complex<double>, ops::ReshapeGradKernel);
 REGISTER_OP_CUDA_KERNEL_FUNCTOR(
     reshape2_grad_grad, float, ops::ReshapeDoubleGradKernel, double,
     ops::ReshapeDoubleGradKernel, int, ops::ReshapeDoubleGradKernel, uint8_t,
     ops::ReshapeDoubleGradKernel, int64_t, ops::ReshapeDoubleGradKernel,
     plat::float16, ops::ReshapeDoubleGradKernel, bool,
-    ops::ReshapeDoubleGradKernel, plat::complex64, ops::ReshapeDoubleGradKernel,
-    plat::complex128, ops::ReshapeDoubleGradKernel);
+    ops::ReshapeDoubleGradKernel, plat::complex<float>,
+    ops::ReshapeDoubleGradKernel, plat::complex<double>,
+    ops::ReshapeDoubleGradKernel);
 #endif
 #ifdef PADDLE_WITH_XPU
@@ -673,14 +675,14 @@ REGISTER_OP_XPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
                                ops::ReshapeKernel, int, ops::ReshapeKernel,
                                int64_t, ops::ReshapeKernel, plat::float16,
                                ops::ReshapeKernel, bool, ops::ReshapeKernel,
-                               plat::complex64, ops::ReshapeKernel,
-                               plat::complex128, ops::ReshapeKernel);
+                               plat::complex<float>, ops::ReshapeKernel,
+                               plat::complex<double>, ops::ReshapeKernel);
 REGISTER_OP_XPU_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
                                double, ops::ReshapeGradKernel, int,
                                ops::ReshapeGradKernel, int64_t,
                                ops::ReshapeGradKernel, plat::float16,
                                ops::ReshapeGradKernel, bool,
-                               ops::ReshapeGradKernel, plat::complex64,
-                               ops::ReshapeGradKernel, plat::complex128,
+                               ops::ReshapeGradKernel, plat::complex<float>,
+                               ops::ReshapeGradKernel, plat::complex<double>,
                                ops::ReshapeGradKernel);
 #endif
@@ -436,9 +436,9 @@ REGISTER_OP_CPU_KERNEL(
     ops::SliceKernel<paddle::platform::CPUDeviceContext, float>,
     ops::SliceKernel<paddle::platform::CPUDeviceContext, double>,
     ops::SliceKernel<paddle::platform::CPUDeviceContext,
-                     paddle::platform::complex64>,
+                     paddle::platform::complex<float>>,
     ops::SliceKernel<paddle::platform::CPUDeviceContext,
-                     paddle::platform::complex128>);
+                     paddle::platform::complex<double>>);
 REGISTER_OP_CPU_KERNEL(
     slice_grad, ops::SliceGradKernel<paddle::platform::CPUDeviceContext, int>,
@@ -446,9 +446,9 @@ REGISTER_OP_CPU_KERNEL(
     ops::SliceGradKernel<paddle::platform::CPUDeviceContext, float>,
     ops::SliceGradKernel<paddle::platform::CPUDeviceContext, double>,
     ops::SliceGradKernel<paddle::platform::CPUDeviceContext,
-                         paddle::platform::complex64>,
+                         paddle::platform::complex<float>>,
     ops::SliceGradKernel<paddle::platform::CPUDeviceContext,
-                         paddle::platform::complex128>);
+                         paddle::platform::complex<double>>);
 REGISTER_OP_CUDA_KERNEL(
     slice, ops::SliceKernel<paddle::platform::CUDADeviceContext, float>,
@@ -458,9 +458,9 @@ REGISTER_OP_CUDA_KERNEL(
     ops::SliceKernel<paddle::platform::CUDADeviceContext,
                      paddle::platform::float16>,
     ops::SliceKernel<paddle::platform::CUDADeviceContext,
-                     paddle::platform::complex64>,
+                     paddle::platform::complex<float>>,
     ops::SliceKernel<paddle::platform::CUDADeviceContext,
-                     paddle::platform::complex128>);
+                     paddle::platform::complex<double>>);
 REGISTER_OP_CUDA_KERNEL(
     slice_grad,
@@ -471,6 +471,6 @@ REGISTER_OP_CUDA_KERNEL(
     ops::SliceGradKernel<paddle::platform::CUDADeviceContext,
                          paddle::platform::float16>,
     ops::SliceGradKernel<paddle::platform::CUDADeviceContext,
-                         paddle::platform::complex64>,
+                         paddle::platform::complex<float>>,
     ops::SliceGradKernel<paddle::platform::CUDADeviceContext,
-                         paddle::platform::complex128>);
+                         paddle::platform::complex<double>>);
@@ -167,18 +167,18 @@ REGISTER_OP_CPU_KERNEL(
     ops::TraceKernel<paddle::platform::CPUDeviceContext, double>,
     ops::TraceKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::TraceKernel<paddle::platform::CPUDeviceContext,
-                     paddle::platform::complex64>,
+                     paddle::platform::complex<float>>,
     ops::TraceKernel<paddle::platform::CPUDeviceContext,
-                     paddle::platform::complex128>);
+                     paddle::platform::complex<double>>);
 REGISTER_OP_CPU_KERNEL(
     trace_grad, ops::TraceGradKernel<paddle::platform::CPUDeviceContext, int>,
     ops::TraceGradKernel<paddle::platform::CPUDeviceContext, float>,
     ops::TraceGradKernel<paddle::platform::CPUDeviceContext, double>,
     ops::TraceGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::TraceGradKernel<paddle::platform::CPUDeviceContext,
-                         paddle::platform::complex64>,
+                         paddle::platform::complex<float>>,
     ops::TraceGradKernel<paddle::platform::CPUDeviceContext,
-                         paddle::platform::complex128>);
+                         paddle::platform::complex<double>>);
 /* ========================== register checkpoint ===========================*/
 REGISTER_OP_VERSION(trace)
@@ -64,9 +64,9 @@ REGISTER_OP_CUDA_KERNEL(
     ops::TraceCUDAKernel<paddle::platform::CUDADeviceContext, float>,
     ops::TraceCUDAKernel<paddle::platform::CUDADeviceContext, double>,
     ops::TraceCUDAKernel<paddle::platform::CUDADeviceContext,
-                         paddle::platform::complex64>,
+                         paddle::platform::complex<float>>,
     ops::TraceCUDAKernel<paddle::platform::CUDADeviceContext,
-                         paddle::platform::complex128>);
+                         paddle::platform::complex<double>>);
 REGISTER_OP_CUDA_KERNEL(
     trace_grad, ops::TraceGradKernel<paddle::platform::CUDADeviceContext, int>,
     ops::TraceGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
@@ -75,6 +75,6 @@ REGISTER_OP_CUDA_KERNEL(
     ops::TraceGradKernel<paddle::platform::CUDADeviceContext, float>,
     ops::TraceGradKernel<paddle::platform::CUDADeviceContext, double>,
     ops::TraceGradKernel<paddle::platform::CUDADeviceContext,
-                         paddle::platform::complex64>,
+                         paddle::platform::complex<float>>,
     ops::TraceGradKernel<paddle::platform::CUDADeviceContext,
-                         paddle::platform::complex128>);
+                         paddle::platform::complex<double>>);
@@ -341,17 +341,17 @@ REGISTER_OP_CPU_KERNEL(
     transpose, ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>,
     ops::TransposeKernel<paddle::platform::CPUDeviceContext, double>,
     ops::TransposeKernel<paddle::platform::CPUDeviceContext,
-                         paddle::platform::complex64>,
+                         paddle::platform::complex<float>>,
     ops::TransposeKernel<paddle::platform::CPUDeviceContext,
-                         paddle::platform::complex128>);
+                         paddle::platform::complex<double>>);
 REGISTER_OP_CPU_KERNEL(
     transpose_grad,
     ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
     ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, double>,
     ops::TransposeGradKernel<paddle::platform::CPUDeviceContext,
-                             paddle::platform::complex64>,
+                             paddle::platform::complex<float>>,
     ops::TransposeGradKernel<paddle::platform::CPUDeviceContext,
-                             paddle::platform::complex128>);
+                             paddle::platform::complex<double>>);
 REGISTER_OPERATOR(transpose2, ops::Transpose2Op, ops::Transpose2OpMaker,
                   ops::Transpose2GradMaker<paddle::framework::OpDesc>,
@@ -366,9 +366,9 @@ REGISTER_OP_CPU_KERNEL(
     ops::TransposeKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::TransposeKernel<paddle::platform::CPUDeviceContext, double>,
     ops::TransposeKernel<paddle::platform::CPUDeviceContext,
-                         paddle::platform::complex64>,
+                         paddle::platform::complex<float>>,
     ops::TransposeKernel<paddle::platform::CPUDeviceContext,
-                         paddle::platform::complex128>);
+                         paddle::platform::complex<double>>);
 REGISTER_OP_CPU_KERNEL(
     transpose2_grad,
     ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, int32_t>,
@@ -376,6 +376,6 @@ REGISTER_OP_CPU_KERNEL(
     ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
     ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, double>,
     ops::TransposeGradKernel<paddle::platform::CPUDeviceContext,
-                             paddle::platform::complex64>,
+                             paddle::platform::complex<float>>,
     ops::TransposeGradKernel<paddle::platform::CPUDeviceContext,
-                             paddle::platform::complex128>);
+                             paddle::platform::complex<double>>);
@@ -732,9 +732,9 @@ REGISTER_OP_CUDA_KERNEL(
     ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, double>,
     ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, plat::float16>,
     ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext,
-                            paddle::platform::complex64>,
+                            paddle::platform::complex<float>>,
     ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext,
-                            paddle::platform::complex128>);
+                            paddle::platform::complex<double>>);
 REGISTER_OP_CUDA_KERNEL(
     transpose_grad,
     ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext, float>,
@@ -742,9 +742,9 @@ REGISTER_OP_CUDA_KERNEL(
     ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext,
                                 plat::float16>,
     ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext,
-                                paddle::platform::complex64>,
+                                paddle::platform::complex<float>>,
     ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext,
-                                paddle::platform::complex128>);
+                                paddle::platform::complex<double>>);
 REGISTER_OP_CUDA_KERNEL(
     transpose2,
@@ -754,9 +754,9 @@ REGISTER_OP_CUDA_KERNEL(
     ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, double>,
     ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, plat::float16>,
     ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext,
-                            paddle::platform::complex64>,
+                            paddle::platform::complex<float>>,
     ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext,
-                            paddle::platform::complex128>);
+                            paddle::platform::complex<double>>);
 REGISTER_OP_CUDA_KERNEL(
     transpose2_grad,
     ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext, int32_t>,
@@ -766,6 +766,6 @@ REGISTER_OP_CUDA_KERNEL(
     ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext,
                                 plat::float16>,
     ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext,
-                                paddle::platform::complex64>,
+                                paddle::platform::complex<float>>,
     ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext,
-                                paddle::platform::complex128>);
+                                paddle::platform::complex<double>>);