Unverified commit 6c399d94, authored by chentianyu03, committed by GitHub

Modify Ops from complex64/128 to complex<float/double> types (#33133)

* modify kron OP to complex template types

* modify reshape, slice, trace, transpose OPs to complex template types

* modify eigen slice files to complex template types

* change pad.cc and pad.cu to complex template types

* format code style
Parent 6a5b7e59
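The whole commit is mechanical: everywhere the framework previously spelled out two fixed-width complex types (complex64 and complex128), it now uses the single class template paddle::platform::complex<T> from paddle/fluid/platform/complex.h, instantiated with float or double. Below is a minimal standalone sketch of why one template is enough; the real paddle::platform::complex<T> carries many more operators and conversions, and the .real/.imag fields here only mirror the ones the kron functors in this diff access:

#include <iostream>

// Sketch: one class template replaces the old fixed-width pair
// (complex64 == complex<float>, complex128 == complex<double>).
template <typename T>
struct complex {
  T real;
  T imag;
  complex(T r, T i) : real(r), imag(i) {}
};

// A single definition now covers both precisions.
template <typename T>
complex<T> operator*(const complex<T>& a, const complex<T>& b) {
  return complex<T>(a.real * b.real - a.imag * b.imag,
                    a.real * b.imag + a.imag * b.real);
}

int main() {
  complex<float> cf(1.0f, 2.0f);        // the role of the old complex64
  complex<double> cd(1.0, 2.0);         // the role of the old complex128
  std::cout << (cf * cf).real << "\n";  // (1+2i)^2 = -3+4i, prints -3
  std::cout << (cd * cd).imag << "\n";  // prints 4
  return 0;
}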
@@ -12,8 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/operators/eigen/eigen_function.h"
-#include "paddle/fluid/platform/complex128.h"
-#include "paddle/fluid/platform/complex64.h"
+#include "paddle/fluid/platform/complex.h"

 namespace paddle {
 namespace operators {
@@ -56,8 +55,8 @@ INSTANTIATION(EigenPad, int);
 INSTANTIATION(EigenPad, int64_t);
 INSTANTIATION(EigenPad, float);
 INSTANTIATION(EigenPad, double);
-INSTANTIATION(EigenPad, platform::complex64);
-INSTANTIATION(EigenPad, platform::complex128);
+INSTANTIATION(EigenPad, platform::complex<float>);
+INSTANTIATION(EigenPad, platform::complex<double>);
 #undef INSTANTIATION

 }  // namespace operators
...
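The INSTANTIATION lines in these eigen files are explicit template instantiations: each line forces the compiler to emit the functor's code for one element type in this translation unit, so swapping complex64/complex128 for complex<float>/complex<double> is a one-line change per type. The macro's real definition lives in eigen_function.h and is not part of this diff; the sketch below uses a hypothetical PadFunctor and a simplified INSTANTIATION to illustrate the pattern only:

// Hypothetical stand-in for the Eigen-based functor templates.
template <typename T>
struct PadFunctor {
  void operator()(T* dst, const T* src, int n, T value) const {
    dst[0] = value;  // left pad
    for (int i = 0; i < n; ++i) dst[i + 1] = src[i];
    dst[n + 1] = value;  // right pad
  }
};

// Explicit instantiation: emits object code for each listed type here,
// so other translation units only need the template's declaration.
#define INSTANTIATION(FUNCTOR, TYPE) template struct FUNCTOR<TYPE>
INSTANTIATION(PadFunctor, float);
INSTANTIATION(PadFunctor, double);
#undef INSTANTIATION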
@@ -12,8 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/operators/eigen/eigen_function.h"
-#include "paddle/fluid/platform/complex128.h"
-#include "paddle/fluid/platform/complex64.h"
+#include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/float16.h"

 namespace paddle {
@@ -58,8 +57,8 @@ INSTANTIATION(EigenPad, int64_t);
 INSTANTIATION(EigenPad, float);
 INSTANTIATION(EigenPad, double);
 INSTANTIATION(EigenPad, platform::float16);
-INSTANTIATION(EigenPad, platform::complex64);
-INSTANTIATION(EigenPad, platform::complex128);
+INSTANTIATION(EigenPad, platform::complex<float>);
+INSTANTIATION(EigenPad, platform::complex<double>);
 #undef INSTANTIATION

 }  // namespace operators
...
@@ -14,8 +14,6 @@ limitations under the License. */
 #include "paddle/fluid/operators/eigen/eigen_function.h"
 #include "paddle/fluid/platform/bfloat16.h"
 #include "paddle/fluid/platform/complex.h"
-#include "paddle/fluid/platform/complex128.h"
-#include "paddle/fluid/platform/complex64.h"
 #include "paddle/fluid/platform/float16.h"

 namespace paddle {
@@ -69,8 +67,6 @@ INSTANTIATION(EigenSlice, float);
 INSTANTIATION(EigenSlice, double);
 INSTANTIATION(EigenSlice, platform::float16);
 INSTANTIATION(EigenSlice, platform::bfloat16);
-INSTANTIATION(EigenSlice, platform::complex64);
-INSTANTIATION(EigenSlice, platform::complex128);
 INSTANTIATION(EigenSlice, platform::complex<float>);
 INSTANTIATION(EigenSlice, platform::complex<double>);
 #undef INSTANTIATION
...
@@ -12,8 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/operators/eigen/eigen_function.h"
-#include "paddle/fluid/platform/complex128.h"
-#include "paddle/fluid/platform/complex64.h"
+#include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/float16.h"

 namespace paddle {
@@ -58,8 +57,8 @@ INSTANTIATION(EigenSlice, int64_t);
 INSTANTIATION(EigenSlice, float);
 INSTANTIATION(EigenSlice, double);
 INSTANTIATION(EigenSlice, platform::float16);
-INSTANTIATION(EigenSlice, platform::complex64);
-INSTANTIATION(EigenSlice, platform::complex128);
+INSTANTIATION(EigenSlice, platform::complex<float>);
+INSTANTIATION(EigenSlice, platform::complex<double>);
 #undef INSTANTIATION

 }  // namespace operators
...
@@ -18,8 +18,7 @@ limitations under the License. */
 #include <vector>

 #include "paddle/fluid/operators/kron_op.h"
-#include "paddle/fluid/platform/complex128.h"
-#include "paddle/fluid/platform/complex64.h"
+#include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/float16.h"

 namespace paddle {
@@ -185,9 +184,9 @@ REGISTER_OP_CPU_KERNEL(
     ops::KronKernel<paddle::platform::CPUDeviceContext, int>,
     ops::KronKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::KronKernel<paddle::platform::CPUDeviceContext,
-                    paddle::platform::complex64>,
+                    paddle::platform::complex<float>>,
     ops::KronKernel<paddle::platform::CPUDeviceContext,
-                    paddle::platform::complex128>);
+                    paddle::platform::complex<double>>);

 REGISTER_OPERATOR(kron_grad, ops::KronGradOp);
 REGISTER_OP_CPU_KERNEL(
@@ -198,6 +197,6 @@ REGISTER_OP_CPU_KERNEL(
     ops::KronGradKernel<paddle::platform::CPUDeviceContext, int>,
     ops::KronGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::KronGradKernel<paddle::platform::CPUDeviceContext,
-                        paddle::platform::complex64>,
+                        paddle::platform::complex<float>>,
     ops::KronGradKernel<paddle::platform::CPUDeviceContext,
-                        paddle::platform::complex128>);
+                        paddle::platform::complex<double>>);
...
@@ -13,8 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/operators/kron_op.h"
-#include "paddle/fluid/platform/complex128.h"
-#include "paddle/fluid/platform/complex64.h"
+#include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/float16.h"

 namespace ops = paddle::operators;
@@ -26,9 +25,9 @@ REGISTER_OP_CUDA_KERNEL(
     ops::KronKernel<paddle::platform::CUDADeviceContext, int>,
     ops::KronKernel<paddle::platform::CUDADeviceContext, int64_t>,
     ops::KronKernel<paddle::platform::CUDADeviceContext,
-                    paddle::platform::complex64>,
+                    paddle::platform::complex<float>>,
     ops::KronKernel<paddle::platform::CUDADeviceContext,
-                    paddle::platform::complex128>);
+                    paddle::platform::complex<double>>);

 REGISTER_OP_CUDA_KERNEL(
     kron_grad, ops::KronGradKernel<paddle::platform::CUDADeviceContext, float>,
@@ -38,6 +37,6 @@ REGISTER_OP_CUDA_KERNEL(
     ops::KronGradKernel<paddle::platform::CUDADeviceContext, int>,
     ops::KronGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
     ops::KronGradKernel<paddle::platform::CUDADeviceContext,
-                        paddle::platform::complex64>,
+                        paddle::platform::complex<float>>,
     ops::KronGradKernel<paddle::platform::CUDADeviceContext,
-                        paddle::platform::complex128>);
+                        paddle::platform::complex<double>>);
...
@@ -26,9 +26,6 @@ limitations under the License. */
 namespace paddle {
 namespace operators {

-using complex64 = paddle::platform::complex64;
-using complex128 = paddle::platform::complex128;
-
 // Process an element in the output, used with a parallel-for
 template <typename T>
 struct KronElemFunctor {
@@ -175,72 +172,13 @@ struct KronGradElemFunctor {
   const int ndims_;
 };

-template <>
-struct KronGradElemFunctor<complex64> {
-  KronGradElemFunctor(const complex64* dout, const complex64* A,
-                      const complex64* B, complex64* dout_a, complex64* dout_b,
-                      const int64_t* stride_dout, const int64_t* stride_a,
-                      const int64_t* stride_b, const int64_t* shape_b,
-                      const int64_t numel_a, const int64_t numel_b,
-                      const int ndims)
-      : dout_(dout),
-        A_(A),
-        B_(B),
-        dout_a_(dout_a),
-        dout_b_(dout_b),
-        stride_dout_(stride_dout),
-        stride_a_(stride_a),
-        stride_b_(stride_b),
-        shape_b_(shape_b),
-        numel_a_(numel_a),
-        numel_b_(numel_b),
-        ndims_(ndims) {}
-
-  HOSTDEVICE void operator()(int64_t idx) {
-    int64_t index = idx;
-    int64_t index_a = 0;
-    int64_t index_b = 0;
-    for (int i = 0; i < ndims_; i++) {
-      auto pos_i = index / stride_dout_[i];
-      index = index % stride_dout_[i];
-      auto pos_ai = pos_i / shape_b_[i];
-      auto pos_bi = pos_i % shape_b_[i];
-      index_a += stride_a_[i] * pos_ai;
-      index_b += stride_b_[i] * pos_bi;
-    }
-
-    if (dout_a_) {
-      size_t index_out_a = index_a * numel_b_ + index_b;
-      dout_a_[index_out_a] =
-          dout_[idx] * complex64(B_[index_b].real, -B_[index_b].imag);
-    }
-    if (dout_b_) {
-      size_t index_out_b = index_b * numel_a_ + index_a;
-      dout_b_[index_out_b] =
-          dout_[idx] * complex64(A_[index_a].real, -A_[index_a].imag);
-    }
-  }
-
- private:
-  const complex64* dout_;
-  const complex64* A_;
-  const complex64* B_;
-  complex64* dout_a_;
-  complex64* dout_b_;
-  const int64_t* stride_dout_;
-  const int64_t* stride_a_;
-  const int64_t* stride_b_;
-  const int64_t* shape_b_;
-  const int64_t numel_a_;
-  const int64_t numel_b_;
-  const int ndims_;
-};
-
-template <>
-struct KronGradElemFunctor<complex128> {
-  KronGradElemFunctor(const complex128* dout, const complex128* A,
-                      const complex128* B, complex128* dout_a,
-                      complex128* dout_b, const int64_t* stride_dout,
+template <typename T>
+struct KronGradElemFunctor<platform::complex<T>> {
+  KronGradElemFunctor(const platform::complex<T>* dout,
+                      const platform::complex<T>* A,
+                      const platform::complex<T>* B,
+                      platform::complex<T>* dout_a,
+                      platform::complex<T>* dout_b, const int64_t* stride_dout,
                       const int64_t* stride_a, const int64_t* stride_b,
                       const int64_t* shape_b, const int64_t numel_a,
                       const int64_t numel_b, const int ndims)
@@ -273,21 +211,23 @@ struct KronGradElemFunctor<complex128> {
     if (dout_a_) {
       size_t index_out_a = index_a * numel_b_ + index_b;
       dout_a_[index_out_a] =
-          dout_[idx] * complex128(B_[index_b].real, -B_[index_b].imag);
+          dout_[idx] *
+          platform::complex<T>(B_[index_b].real, -B_[index_b].imag);
     }
     if (dout_b_) {
       size_t index_out_b = index_b * numel_a_ + index_a;
       dout_b_[index_out_b] =
-          dout_[idx] * complex128(A_[index_a].real, -A_[index_a].imag);
+          dout_[idx] *
+          platform::complex<T>(A_[index_a].real, -A_[index_a].imag);
     }
   }

  private:
-  const complex128* dout_;
-  const complex128* A_;
-  const complex128* B_;
-  complex128* dout_a_;
-  complex128* dout_b_;
+  const platform::complex<T>* dout_;
+  const platform::complex<T>* A_;
+  const platform::complex<T>* B_;
+  platform::complex<T>* dout_a_;
+  platform::complex<T>* dout_b_;
   const int64_t* stride_dout_;
   const int64_t* stride_a_;
   const int64_t* stride_b_;
...
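The kron_op.h change above is the one place where code is consolidated rather than just renamed: two hand-copied explicit specializations, KronGradElemFunctor<complex64> and KronGradElemFunctor<complex128>, become a single partial specialization KronGradElemFunctor<platform::complex<T>>. The specialization exists because the complex gradient must multiply by the conjugate, complex<T>(x.real, -x.imag), rather than by the value itself. A self-contained sketch of that pattern, with an illustrative GradElem functor standing in for the real one:

template <typename T>
struct complex {  // stand-in for paddle::platform::complex<T>
  T real, imag;
};

// Primary template: real-valued types need no conjugation.
template <typename T>
struct GradElem {
  T operator()(T dout, T b) const { return dout * b; }
};

// One partial specialization covers complex<float> and complex<double>;
// before this commit the body was duplicated once per width.
template <typename T>
struct GradElem<complex<T>> {
  complex<T> operator()(complex<T> dout, complex<T> b) const {
    // dout * conj(b), i.e. dout * complex<T>(b.real, -b.imag).
    return complex<T>{dout.real * b.real + dout.imag * b.imag,
                      dout.imag * b.real - dout.real * b.imag};
  }
};

int main() {
  GradElem<complex<float>> grad;
  complex<float> g = grad({1.0f, 1.0f}, {0.0f, 1.0f});
  // (1 + i) * conj(i) = (1 + i) * (-i) = 1 - i
  return (g.real == 1.0f && g.imag == -1.0f) ? 0 : 1;
}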
@@ -613,23 +613,24 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(
     reshape2, float, ops::ReshapeKernel, double, ops::ReshapeKernel, int8_t,
     ops::ReshapeKernel, uint8_t, ops::ReshapeKernel, int, ops::ReshapeKernel,
     int64_t, ops::ReshapeKernel, bool, ops::ReshapeKernel,
-    paddle::platform::bfloat16, ops::ReshapeKernel, paddle::platform::complex64,
-    ops::ReshapeKernel, paddle::platform::complex128, ops::ReshapeKernel);
+    paddle::platform::bfloat16, ops::ReshapeKernel,
+    paddle::platform::complex<float>, ops::ReshapeKernel,
+    paddle::platform::complex<double>, ops::ReshapeKernel);

 REGISTER_OP_CPU_KERNEL_FUNCTOR(
     reshape2_grad, float, ops::ReshapeGradKernel, double,
     ops::ReshapeGradKernel, int, ops::ReshapeGradKernel, uint8_t,
     ops::ReshapeGradKernel, int64_t, ops::ReshapeGradKernel, bool,
     ops::ReshapeGradKernel, paddle::platform::bfloat16, ops::ReshapeGradKernel,
-    paddle::platform::complex64, ops::ReshapeGradKernel,
-    paddle::platform::complex128, ops::ReshapeGradKernel);
+    paddle::platform::complex<float>, ops::ReshapeGradKernel,
+    paddle::platform::complex<double>, ops::ReshapeGradKernel);

 REGISTER_OP_CPU_KERNEL_FUNCTOR(
     reshape2_grad_grad, float, ops::ReshapeDoubleGradKernel, double,
     ops::ReshapeDoubleGradKernel, int, ops::ReshapeDoubleGradKernel, uint8_t,
     ops::ReshapeDoubleGradKernel, int64_t, ops::ReshapeDoubleGradKernel, bool,
     ops::ReshapeDoubleGradKernel, paddle::platform::bfloat16,
-    ops::ReshapeDoubleGradKernel, paddle::platform::complex64,
-    ops::ReshapeDoubleGradKernel, paddle::platform::complex128,
+    ops::ReshapeDoubleGradKernel, paddle::platform::complex<float>,
+    ops::ReshapeDoubleGradKernel, paddle::platform::complex<double>,
     ops::ReshapeDoubleGradKernel);

 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
@@ -650,22 +651,23 @@ REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
                                 uint8_t, ops::ReshapeKernel, int64_t,
                                 ops::ReshapeKernel, plat::float16,
                                 ops::ReshapeKernel, bool, ops::ReshapeKernel,
-                                plat::complex64, ops::ReshapeKernel,
-                                plat::complex128, ops::ReshapeKernel);
+                                plat::complex<float>, ops::ReshapeKernel,
+                                plat::complex<double>, ops::ReshapeKernel);

 REGISTER_OP_CUDA_KERNEL_FUNCTOR(
     reshape2_grad, float, ops::ReshapeGradKernel, double,
     ops::ReshapeGradKernel, int, ops::ReshapeGradKernel, uint8_t,
     ops::ReshapeGradKernel, int64_t, ops::ReshapeGradKernel, plat::float16,
-    ops::ReshapeGradKernel, bool, ops::ReshapeGradKernel, plat::complex64,
-    ops::ReshapeGradKernel, plat::complex128, ops::ReshapeGradKernel);
+    ops::ReshapeGradKernel, bool, ops::ReshapeGradKernel, plat::complex<float>,
+    ops::ReshapeGradKernel, plat::complex<double>, ops::ReshapeGradKernel);

 REGISTER_OP_CUDA_KERNEL_FUNCTOR(
     reshape2_grad_grad, float, ops::ReshapeDoubleGradKernel, double,
     ops::ReshapeDoubleGradKernel, int, ops::ReshapeDoubleGradKernel, uint8_t,
     ops::ReshapeDoubleGradKernel, int64_t, ops::ReshapeDoubleGradKernel,
     plat::float16, ops::ReshapeDoubleGradKernel, bool,
-    ops::ReshapeDoubleGradKernel, plat::complex64, ops::ReshapeDoubleGradKernel,
-    plat::complex128, ops::ReshapeDoubleGradKernel);
+    ops::ReshapeDoubleGradKernel, plat::complex<float>,
+    ops::ReshapeDoubleGradKernel, plat::complex<double>,
+    ops::ReshapeDoubleGradKernel);
 #endif

 #ifdef PADDLE_WITH_XPU
@@ -673,14 +675,14 @@ REGISTER_OP_XPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
                                ops::ReshapeKernel, int, ops::ReshapeKernel,
                                int64_t, ops::ReshapeKernel, plat::float16,
                                ops::ReshapeKernel, bool, ops::ReshapeKernel,
-                               plat::complex64, ops::ReshapeKernel,
-                               plat::complex128, ops::ReshapeKernel);
+                               plat::complex<float>, ops::ReshapeKernel,
+                               plat::complex<double>, ops::ReshapeKernel);

 REGISTER_OP_XPU_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
                                double, ops::ReshapeGradKernel, int,
                                ops::ReshapeGradKernel, int64_t,
                                ops::ReshapeGradKernel, plat::float16,
                                ops::ReshapeGradKernel, bool,
-                               ops::ReshapeGradKernel, plat::complex64,
-                               ops::ReshapeGradKernel, plat::complex128,
+                               ops::ReshapeGradKernel, plat::complex<float>,
+                               ops::ReshapeGradKernel, plat::complex<double>,
                                ops::ReshapeGradKernel);
 #endif
...
@@ -436,9 +436,9 @@ REGISTER_OP_CPU_KERNEL(
     ops::SliceKernel<paddle::platform::CPUDeviceContext, float>,
     ops::SliceKernel<paddle::platform::CPUDeviceContext, double>,
     ops::SliceKernel<paddle::platform::CPUDeviceContext,
-                     paddle::platform::complex64>,
+                     paddle::platform::complex<float>>,
     ops::SliceKernel<paddle::platform::CPUDeviceContext,
-                     paddle::platform::complex128>);
+                     paddle::platform::complex<double>>);

 REGISTER_OP_CPU_KERNEL(
     slice_grad, ops::SliceGradKernel<paddle::platform::CPUDeviceContext, int>,
@@ -446,9 +446,9 @@ REGISTER_OP_CPU_KERNEL(
     ops::SliceGradKernel<paddle::platform::CPUDeviceContext, float>,
     ops::SliceGradKernel<paddle::platform::CPUDeviceContext, double>,
     ops::SliceGradKernel<paddle::platform::CPUDeviceContext,
-                         paddle::platform::complex64>,
+                         paddle::platform::complex<float>>,
     ops::SliceGradKernel<paddle::platform::CPUDeviceContext,
-                         paddle::platform::complex128>);
+                         paddle::platform::complex<double>>);

 REGISTER_OP_CUDA_KERNEL(
     slice, ops::SliceKernel<paddle::platform::CUDADeviceContext, float>,
@@ -458,9 +458,9 @@ REGISTER_OP_CUDA_KERNEL(
     ops::SliceKernel<paddle::platform::CUDADeviceContext,
                      paddle::platform::float16>,
     ops::SliceKernel<paddle::platform::CUDADeviceContext,
-                     paddle::platform::complex64>,
+                     paddle::platform::complex<float>>,
     ops::SliceKernel<paddle::platform::CUDADeviceContext,
-                     paddle::platform::complex128>);
+                     paddle::platform::complex<double>>);

 REGISTER_OP_CUDA_KERNEL(
     slice_grad,
@@ -471,6 +471,6 @@ REGISTER_OP_CUDA_KERNEL(
     ops::SliceGradKernel<paddle::platform::CUDADeviceContext,
                          paddle::platform::float16>,
     ops::SliceGradKernel<paddle::platform::CUDADeviceContext,
-                         paddle::platform::complex64>,
+                         paddle::platform::complex<float>>,
     ops::SliceGradKernel<paddle::platform::CUDADeviceContext,
-                         paddle::platform::complex128>);
+                         paddle::platform::complex<double>>);
...
@@ -167,18 +167,18 @@ REGISTER_OP_CPU_KERNEL(
     ops::TraceKernel<paddle::platform::CPUDeviceContext, double>,
     ops::TraceKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::TraceKernel<paddle::platform::CPUDeviceContext,
-                     paddle::platform::complex64>,
+                     paddle::platform::complex<float>>,
     ops::TraceKernel<paddle::platform::CPUDeviceContext,
-                     paddle::platform::complex128>);
+                     paddle::platform::complex<double>>);
 REGISTER_OP_CPU_KERNEL(
     trace_grad, ops::TraceGradKernel<paddle::platform::CPUDeviceContext, int>,
     ops::TraceGradKernel<paddle::platform::CPUDeviceContext, float>,
     ops::TraceGradKernel<paddle::platform::CPUDeviceContext, double>,
     ops::TraceGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::TraceGradKernel<paddle::platform::CPUDeviceContext,
-                         paddle::platform::complex64>,
+                         paddle::platform::complex<float>>,
     ops::TraceGradKernel<paddle::platform::CPUDeviceContext,
-                         paddle::platform::complex128>);
+                         paddle::platform::complex<double>>);

 /* ========================== register checkpoint ===========================*/
 REGISTER_OP_VERSION(trace)
...
@@ -64,9 +64,9 @@ REGISTER_OP_CUDA_KERNEL(
     ops::TraceCUDAKernel<paddle::platform::CUDADeviceContext, float>,
     ops::TraceCUDAKernel<paddle::platform::CUDADeviceContext, double>,
     ops::TraceCUDAKernel<paddle::platform::CUDADeviceContext,
-                         paddle::platform::complex64>,
+                         paddle::platform::complex<float>>,
     ops::TraceCUDAKernel<paddle::platform::CUDADeviceContext,
-                         paddle::platform::complex128>);
+                         paddle::platform::complex<double>>);
 REGISTER_OP_CUDA_KERNEL(
     trace_grad, ops::TraceGradKernel<paddle::platform::CUDADeviceContext, int>,
     ops::TraceGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
@@ -75,6 +75,6 @@ REGISTER_OP_CUDA_KERNEL(
     ops::TraceGradKernel<paddle::platform::CUDADeviceContext, float>,
     ops::TraceGradKernel<paddle::platform::CUDADeviceContext, double>,
     ops::TraceGradKernel<paddle::platform::CUDADeviceContext,
-                         paddle::platform::complex64>,
+                         paddle::platform::complex<float>>,
     ops::TraceGradKernel<paddle::platform::CUDADeviceContext,
-                         paddle::platform::complex128>);
+                         paddle::platform::complex<double>>);
...
@@ -341,17 +341,17 @@ REGISTER_OP_CPU_KERNEL(
     transpose, ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>,
     ops::TransposeKernel<paddle::platform::CPUDeviceContext, double>,
     ops::TransposeKernel<paddle::platform::CPUDeviceContext,
-                         paddle::platform::complex64>,
+                         paddle::platform::complex<float>>,
     ops::TransposeKernel<paddle::platform::CPUDeviceContext,
-                         paddle::platform::complex128>);
+                         paddle::platform::complex<double>>);
 REGISTER_OP_CPU_KERNEL(
     transpose_grad,
     ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
     ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, double>,
     ops::TransposeGradKernel<paddle::platform::CPUDeviceContext,
-                             paddle::platform::complex64>,
+                             paddle::platform::complex<float>>,
     ops::TransposeGradKernel<paddle::platform::CPUDeviceContext,
-                             paddle::platform::complex128>);
+                             paddle::platform::complex<double>>);

 REGISTER_OPERATOR(transpose2, ops::Transpose2Op, ops::Transpose2OpMaker,
                   ops::Transpose2GradMaker<paddle::framework::OpDesc>,
@@ -366,9 +366,9 @@ REGISTER_OP_CPU_KERNEL(
     ops::TransposeKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::TransposeKernel<paddle::platform::CPUDeviceContext, double>,
     ops::TransposeKernel<paddle::platform::CPUDeviceContext,
-                         paddle::platform::complex64>,
+                         paddle::platform::complex<float>>,
     ops::TransposeKernel<paddle::platform::CPUDeviceContext,
-                         paddle::platform::complex128>);
+                         paddle::platform::complex<double>>);
 REGISTER_OP_CPU_KERNEL(
     transpose2_grad,
     ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, int32_t>,
@@ -376,6 +376,6 @@ REGISTER_OP_CPU_KERNEL(
     ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
     ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, double>,
     ops::TransposeGradKernel<paddle::platform::CPUDeviceContext,
-                             paddle::platform::complex64>,
+                             paddle::platform::complex<float>>,
     ops::TransposeGradKernel<paddle::platform::CPUDeviceContext,
-                             paddle::platform::complex128>);
+                             paddle::platform::complex<double>>);
...
@@ -732,9 +732,9 @@ REGISTER_OP_CUDA_KERNEL(
     ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, double>,
     ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, plat::float16>,
     ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext,
-                            paddle::platform::complex64>,
+                            paddle::platform::complex<float>>,
     ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext,
-                            paddle::platform::complex128>);
+                            paddle::platform::complex<double>>);
 REGISTER_OP_CUDA_KERNEL(
     transpose_grad,
     ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext, float>,
@@ -742,9 +742,9 @@ REGISTER_OP_CUDA_KERNEL(
     ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext,
                                 plat::float16>,
     ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext,
-                                paddle::platform::complex64>,
+                                paddle::platform::complex<float>>,
     ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext,
-                                paddle::platform::complex128>);
+                                paddle::platform::complex<double>>);

 REGISTER_OP_CUDA_KERNEL(
     transpose2,
@@ -754,9 +754,9 @@ REGISTER_OP_CUDA_KERNEL(
     ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, double>,
     ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, plat::float16>,
     ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext,
-                            paddle::platform::complex64>,
+                            paddle::platform::complex<float>>,
     ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext,
-                            paddle::platform::complex128>);
+                            paddle::platform::complex<double>>);
 REGISTER_OP_CUDA_KERNEL(
     transpose2_grad,
     ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext, int32_t>,
@@ -766,6 +766,6 @@ REGISTER_OP_CUDA_KERNEL(
     ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext,
                                 plat::float16>,
     ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext,
-                                paddle::platform::complex64>,
+                                paddle::platform::complex<float>>,
     ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext,
-                                paddle::platform::complex128>);
+                                paddle::platform::complex<double>>);