未验证 提交 34069c46 编写于 作者: Z zhangyuqin1998 提交者: GitHub

rename_bilinear_tensor_product (#52375)

* rename_bilinear_tensor_product

* fix
上级 a043d361
...@@ -92,7 +92,7 @@ namespace ops = paddle::operators; ...@@ -92,7 +92,7 @@ namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(bilinear_tensor_product, DECLARE_INFER_SHAPE_FUNCTOR(bilinear_tensor_product,
BilinearTensorProductInferShapeFunctor, BilinearTensorProductInferShapeFunctor,
PD_INFER_META(phi::BilinearTensorProductInferMeta)); PD_INFER_META(phi::BilinearInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR( DECLARE_INFER_SHAPE_FUNCTOR(
bilinear_tensor_product_grad, bilinear_tensor_product_grad,
BilinearTensorProductGradInferShapeFunctor, BilinearTensorProductGradInferShapeFunctor,
......
...@@ -152,7 +152,7 @@ ...@@ -152,7 +152,7 @@
infer_meta : infer_meta :
func : BilinearTensorProductGradInferMeta func : BilinearTensorProductGradInferMeta
kernel : kernel :
func : bilinear_tensor_product_grad func : bilinear_grad
- backward_op : cast_grad - backward_op : cast_grad
forward : cast (Tensor x, DataType dtype) -> Tensor(out) forward : cast (Tensor x, DataType dtype) -> Tensor(out)
......
...@@ -227,9 +227,9 @@ ...@@ -227,9 +227,9 @@
args : (Tensor x, Tensor y, Tensor weight, Tensor bias) args : (Tensor x, Tensor y, Tensor weight, Tensor bias)
output : Tensor output : Tensor
infer_meta : infer_meta :
func : BilinearTensorProductInferMeta func : BilinearInferMeta
kernel : kernel :
func : bilinear_tensor_product func : bilinear
optional : bias optional : bias
backward : bilinear_tensor_product_grad backward : bilinear_tensor_product_grad
......
...@@ -695,7 +695,7 @@ void BatchNormInferInferMeta(const MetaTensor& x, ...@@ -695,7 +695,7 @@ void BatchNormInferInferMeta(const MetaTensor& x,
config); config);
} }
void BilinearTensorProductInferMeta(const MetaTensor& x, void BilinearInferMeta(const MetaTensor& x,
const MetaTensor& y, const MetaTensor& y,
const MetaTensor& weight, const MetaTensor& weight,
const MetaTensor& bias, const MetaTensor& bias,
......
...@@ -198,7 +198,7 @@ void BatchNormInferInferMeta(const MetaTensor& x, ...@@ -198,7 +198,7 @@ void BatchNormInferInferMeta(const MetaTensor& x,
MetaTensor* variance_out, MetaTensor* variance_out,
MetaConfig config = MetaConfig()); MetaConfig config = MetaConfig());
void BilinearTensorProductInferMeta(const MetaTensor& x, void BilinearInferMeta(const MetaTensor& x,
const MetaTensor& y, const MetaTensor& y,
const MetaTensor& weight, const MetaTensor& weight,
const MetaTensor& bias, const MetaTensor& bias,
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
namespace phi { namespace phi {
template <typename T, typename Context> template <typename T, typename Context>
void BilinearTensorProductGradKernel(const Context& dev_ctx, void BilinearGradKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& y, const DenseTensor& y,
const DenseTensor& weight, const DenseTensor& weight,
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
namespace phi { namespace phi {
template <typename T, typename Context> template <typename T, typename Context>
void BilinearTensorProductKernel(const Context& dev_ctx, void BilinearKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& y, const DenseTensor& y,
const DenseTensor& weight, const DenseTensor& weight,
......
...@@ -12,14 +12,10 @@ ...@@ -12,14 +12,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/phi/kernels/bilinear_tensor_product_kernel.h" #include "paddle/phi/kernels/bilinear_grad_kernel.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/bilinear_tensor_product_kernel_impl.h" #include "paddle/phi/kernels/impl/bilinear_grad_kernel_impl.h"
PD_REGISTER_KERNEL(bilinear_tensor_product, PD_REGISTER_KERNEL(
GPU, bilinear_grad, CPU, ALL_LAYOUT, phi::BilinearGradKernel, float, double) {}
ALL_LAYOUT,
phi::BilinearTensorProductKernel,
float,
double) {}
...@@ -12,14 +12,10 @@ ...@@ -12,14 +12,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/phi/kernels/bilinear_tensor_product_kernel.h" #include "paddle/phi/kernels/bilinear_kernel.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/bilinear_tensor_product_kernel_impl.h" #include "paddle/phi/kernels/impl/bilinear_kernel_impl.h"
PD_REGISTER_KERNEL(bilinear_tensor_product, PD_REGISTER_KERNEL(
CPU, bilinear, CPU, ALL_LAYOUT, phi::BilinearKernel, float, double) {}
ALL_LAYOUT,
phi::BilinearTensorProductKernel,
float,
double) {}
...@@ -12,14 +12,10 @@ ...@@ -12,14 +12,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/phi/kernels/bilinear_tensor_product_grad_kernel.h" #include "paddle/phi/kernels/bilinear_grad_kernel.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/bilinear_tensor_product_grad_kernel_impl.h" #include "paddle/phi/kernels/impl/bilinear_grad_kernel_impl.h"
PD_REGISTER_KERNEL(bilinear_tensor_product_grad, PD_REGISTER_KERNEL(
GPU, bilinear_grad, GPU, ALL_LAYOUT, phi::BilinearGradKernel, float, double) {}
ALL_LAYOUT,
phi::BilinearTensorProductGradKernel,
float,
double) {}
...@@ -12,14 +12,10 @@ ...@@ -12,14 +12,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/phi/kernels/bilinear_tensor_product_grad_kernel.h" #include "paddle/phi/kernels/bilinear_kernel.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/bilinear_tensor_product_grad_kernel_impl.h" #include "paddle/phi/kernels/impl/bilinear_kernel_impl.h"
PD_REGISTER_KERNEL(bilinear_tensor_product_grad, PD_REGISTER_KERNEL(
CPU, bilinear, GPU, ALL_LAYOUT, phi::BilinearKernel, float, double) {}
ALL_LAYOUT,
phi::BilinearTensorProductGradKernel,
float,
double) {}
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
namespace phi { namespace phi {
template <typename T, typename Context> template <typename T, typename Context>
void BilinearTensorProductGradKernel(const Context& ctx, void BilinearGradKernel(const Context& ctx,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& y, const DenseTensor& y,
const DenseTensor& weight, const DenseTensor& weight,
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
namespace phi { namespace phi {
template <typename T, typename Context> template <typename T, typename Context>
void BilinearTensorProductKernel(const Context& ctx, void BilinearKernel(const Context& ctx,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& y, const DenseTensor& y,
const DenseTensor& weight, const DenseTensor& weight,
......
...@@ -18,13 +18,12 @@ namespace phi { ...@@ -18,13 +18,12 @@ namespace phi {
KernelSignature BilinearTensorProductOpArgumentMapping( KernelSignature BilinearTensorProductOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature("bilinear", {"X", "Y", "Weight", "Bias"}, {}, {"Out"});
"bilinear_tensor_product", {"X", "Y", "Weight", "Bias"}, {}, {"Out"});
} }
KernelSignature BilinearTensorProductGradOpArgumentMapping( KernelSignature BilinearTensorProductGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("bilinear_tensor_product_grad", return KernelSignature("bilinear_grad",
{"X", "Y", "Weight", "Out@GRAD"}, {"X", "Y", "Weight", "Out@GRAD"},
{}, {},
{"X@GRAD", "Y@GRAD", "Weight@GRAD", "Bias@GRAD"}); {"X@GRAD", "Y@GRAD", "Weight@GRAD", "Bias@GRAD"});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册