未验证 提交 e0dd7f32 编写于 作者: X xiongkun 提交者: GitHub

Einsum grad complex (#44598)

* add complex for einsum grad kernel

* pass the ci
上级 25d3dce1
......@@ -18,5 +18,11 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/einsum_grad_impl.h"
PD_REGISTER_KERNEL(
einsum_grad, CPU, ALL_LAYOUT, phi::EinsumGradKernel, float, double) {}
PD_REGISTER_KERNEL(einsum_grad,
CPU,
ALL_LAYOUT,
phi::EinsumGradKernel,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
......@@ -18,6 +18,14 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/tile_kernel_impl.h"
PD_REGISTER_KERNEL(
tile, CPU, ALL_LAYOUT, phi::TileKernel, bool, float, double, int, int64_t) {
}
PD_REGISTER_KERNEL(tile,
CPU,
ALL_LAYOUT,
phi::TileKernel,
bool,
float,
double,
int,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
......@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
......@@ -75,6 +76,8 @@ struct EigenBroadcastGrad<Eigen::DefaultDevice, T, Rank> {
INSTANTIATION(EigenBroadcast, bool);
INSTANTIATION(EigenBroadcast, dtype::float16);
INSTANTIATION(EigenBroadcast, dtype::bfloat16);
INSTANTIATION(EigenBroadcast, dtype::complex<float>);
INSTANTIATION(EigenBroadcast, dtype::complex<double>);
INSTANTIATION(EigenBroadcast, float);
INSTANTIATION(EigenBroadcast, double);
INSTANTIATION(EigenBroadcast, int);
......@@ -82,6 +85,8 @@ INSTANTIATION(EigenBroadcast, int64_t);
INSTANTIATION(EigenBroadcastGrad, bool);
INSTANTIATION(EigenBroadcastGrad, float);
INSTANTIATION(EigenBroadcastGrad, dtype::float16);
INSTANTIATION(EigenBroadcastGrad, dtype::complex<float>);
INSTANTIATION(EigenBroadcastGrad, dtype::complex<double>);
INSTANTIATION(EigenBroadcastGrad, double);
INSTANTIATION(EigenBroadcastGrad, int);
INSTANTIATION(EigenBroadcastGrad, int64_t);
......
......@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
......@@ -77,12 +78,16 @@ INSTANTIATION(EigenBroadcast, dtype::float16);
INSTANTIATION(EigenBroadcast, dtype::bfloat16);
INSTANTIATION(EigenBroadcast, float);
INSTANTIATION(EigenBroadcast, double);
INSTANTIATION(EigenBroadcast, dtype::complex<float>);
INSTANTIATION(EigenBroadcast, dtype::complex<double>);
INSTANTIATION(EigenBroadcast, int);
INSTANTIATION(EigenBroadcast, int64_t);
INSTANTIATION(EigenBroadcastGrad, bool);
INSTANTIATION(EigenBroadcastGrad, float);
INSTANTIATION(EigenBroadcastGrad, dtype::float16);
INSTANTIATION(EigenBroadcastGrad, double);
INSTANTIATION(EigenBroadcastGrad, dtype::complex<float>);
INSTANTIATION(EigenBroadcastGrad, dtype::complex<double>);
INSTANTIATION(EigenBroadcastGrad, int);
INSTANTIATION(EigenBroadcastGrad, int64_t);
template struct EigenBroadcastGrad<Eigen::GpuDevice, float, 0>;
......
......@@ -24,4 +24,6 @@ PD_REGISTER_KERNEL(einsum_grad,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
......@@ -28,4 +28,6 @@ PD_REGISTER_KERNEL(tile,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册