Unverified commit 60e0c506, authored by YuanRisheng, committed by GitHub

[PHI]Standardise some C++ API (#47385)

* standard api

* fix ci bugs

* fix ci bugs

* fix ci bugs
Parent 6e1c14e3
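In summary, this change standardizes the following operator and C++ symbol names (old → new), keeping the legacy names wired up for compatibility: brelu → hardtanh (kernel func hard_tanh), crop_tensor → crop, gaussian_random → gaussian, graph_send_recv → send_u_recv, graph_send_ue_recv → send_ue_recv, and graph_send_uv → send_uv. It also reorders the addmm attributes from (alpha, beta) to (beta, alpha) and renames the put_along_axis arguments index/value to indices/values. Matching InferMeta functions, functors, and kernel registrations are renamed throughout (e.g. GraphSendRecvInferMeta → SendURecvInferMeta, BReluFunctor → HardTanhFunctor).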
......@@ -261,6 +261,11 @@ struct BaseActivationFunctor {
template <typename T> \
using name##TripleGradFunctor = phi::funcs::name##TripleGradFunctor<T>;
template <typename T>
using BReluFunctor = phi::funcs::HardTanhFunctor<T>;
template <typename T>
using BReluGradFunctor = phi::funcs::HardTanhGradFunctor<T>;
USE_PHI_FUNCTOR(Cos)
USE_PHI_FUNCTOR(Tan)
USE_PHI_FUNCTOR(Acos)
......@@ -275,7 +280,6 @@ USE_PHI_FUNCTOR(Atanh)
USE_PHI_FUNCTOR(Tanh)
USE_PHI_DOUBLE_GRAD_FUNCTOR(Tanh)
USE_PHI_TRIPLE_GRAD_FUNCTOR(Tanh)
USE_PHI_FUNCTOR(BRelu)
USE_PHI_FUNCTOR(ThresholdedRelu)
USE_PHI_FUNCTOR(Relu6)
USE_PHI_FUNCTOR(LeakyRelu)
......
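The hunk above illustrates the compatibility pattern used throughout this change: the functor is renamed on the phi side, while the old fluid-side name survives as a type alias so existing call sites keep compiling. A minimal sketch of the pattern (the functor body is elided here — an assumption, since only the alias lines appear in this diff):

namespace phi {
namespace funcs {
// Renamed functor (formerly BReluFunctor); real body elided.
template <typename T>
struct HardTanhFunctor {};
}  // namespace funcs
}  // namespace phi

// Fluid side: the old name is kept as an alias to the renamed functor,
// so existing USE_PHI_FUNCTOR-style call sites compile unchanged.
template <typename T>
using BReluFunctor = phi::funcs::HardTanhFunctor<T>;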
......@@ -130,6 +130,11 @@ class ActivationGradCudaKernel
}
};
template <typename T>
using CudaBReluFunctor = phi::funcs::CudaHardTanhFunctor<T>;
template <typename T>
using CudaBReluGradFunctor = phi::funcs::CudaHardTanhGradFunctor<T>;
USE_PHI_FUNCTOR(CudaCos)
USE_PHI_FUNCTOR(CudaTan)
USE_PHI_FUNCTOR(CudaAcos)
......@@ -142,7 +147,6 @@ USE_PHI_FUNCTOR(CudaAsinh)
USE_PHI_FUNCTOR(CudaAcosh)
USE_PHI_FUNCTOR(CudaAtanh)
USE_PHI_FUNCTOR(CudaTanh)
USE_PHI_FUNCTOR(CudaBRelu)
USE_PHI_FUNCTOR(CudaLeakyRelu)
USE_PHI_FUNCTOR(CudaThresholdedRelu)
USE_PHI_FUNCTOR(CudaRelu6)
......@@ -276,13 +280,13 @@ REGISTER_OP_KERNEL(
KP,
plat::XPUPlace,
ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
phi::funcs::CudaBReluFunctor<float>>);
phi::funcs::CudaHardTanhFunctor<float>>);
REGISTER_OP_KERNEL(
brelu_grad,
KP,
plat::XPUPlace,
ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
phi::funcs::CudaBReluGradFunctor<float>>);
phi::funcs::CudaHardTanhGradFunctor<float>>);
REGISTER_OP_KERNEL(ceil,
KP,
......
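Note that the KP (XPU) registrations above keep the legacy op names brelu/brelu_grad and only swap in the renamed CudaHardTanhFunctor/CudaHardTanhGradFunctor, so for this backend the change is purely a C++ symbol rename; the op-level rename itself is carried by the YAML specs and compat mappings further down.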
......@@ -75,8 +75,8 @@ class CropTensorOp : public framework::OperatorWithKernel {
x_dim.size()));
if (ctx->IsRuntime()) {
// If true, set the shape of Output(Out) according to Input(Shape) in
// CropTensorKernel with ExecutionContext. Also check LoD in
// CropTensorKernel.
// CropKernel with ExecutionContext. Also check LoD in
// CropKernel.
ctx->ShareLoD("X", /*->*/ "Out");
} else {
auto out_dims = std::vector<int>(shape_dim[0], -1);
......
......@@ -132,7 +132,7 @@ namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(gaussian_random,
GaussianRandomInferShapeFunctor,
PD_INFER_META(phi::GaussianRandomInferMeta));
PD_INFER_META(phi::GaussianInferMeta));
REGISTER_OPERATOR(
gaussian_random,
......
......@@ -127,7 +127,7 @@ namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(graph_send_recv,
GraphSendRecvInferShapeFunctor,
PD_INFER_META(phi::GraphSendRecvInferMeta));
PD_INFER_META(phi::SendURecvInferMeta));
REGISTER_OPERATOR(graph_send_recv,
ops::GraphSendRecvOP,
ops::GraphSendRecvOpMaker,
......
......@@ -140,7 +140,7 @@ namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(graph_send_ue_recv,
GraphSendUERecvInferShapeFunctor,
PD_INFER_META(phi::GraphSendUERecvInferMeta));
PD_INFER_META(phi::SendUERecvInferMeta));
REGISTER_OPERATOR(graph_send_ue_recv,
ops::GraphSendUERecvOP,
ops::GraphSendUERecvOpMaker,
......
......@@ -295,17 +295,6 @@
output : Tensor(x_grad)
invoke : flip(out_grad, axis)
- backward_op : graph_send_uv_grad
forward : graph_send_uv (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, str message_op = "ADD") -> Tensor(out)
args: (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, Tensor out_grad, str message_op = "ADD")
output : Tensor(x_grad), Tensor(y_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [x, y]
kernel :
func : graph_send_uv_grad
data_type : x
- backward_op : lgamma_grad
forward : lgamma(Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
......@@ -336,6 +325,17 @@
kernel :
func : poisson_grad
- backward_op : send_uv_grad
forward : send_uv (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, str message_op = "ADD") -> Tensor(out)
args: (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, Tensor out_grad, str message_op = "ADD")
output : Tensor(x_grad), Tensor(y_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [x, y]
kernel :
func : send_uv_grad
data_type : x
- backward_op : sin_grad
forward : sin (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
......
......@@ -59,7 +59,7 @@
inplace : (grad_grad_out_grad -> grad_grad_x_grad)
- backward_op : addmm_grad
forward : addmm (Tensor input, Tensor x, Tensor y, float alpha, float beta) -> Tensor(out)
forward : addmm (Tensor input, Tensor x, Tensor y, float beta, float alpha) -> Tensor(out)
args : (Tensor input, Tensor x, Tensor y, Tensor out_grad, float alpha, float beta)
output : Tensor(input_grad), Tensor(x_grad), Tensor(y_grad)
infer_meta :
......@@ -198,17 +198,6 @@
kernel :
func : bilinear_tensor_product_grad
- backward_op : brelu_grad
forward : brelu (Tensor x, float t_min, float t_max) -> Tensor(out)
args : (Tensor x, Tensor out_grad, float t_min, float t_max)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : brelu_grad
inplace : (out_grad -> x_grad)
- backward_op : broadcast_tensors_grad
forward : broadcast_tensors (Tensor[] input) -> Tensor[](out)
args : (Tensor[] input, Tensor[] out_grad)
......@@ -401,14 +390,14 @@
func : conv3d_transpose_grad
use_gpudnn : true
- backward_op : crop_tensor_grad
- backward_op : crop_grad
forward : crop_tensor (Tensor x, IntArray shape, IntArray offsets) -> Tensor(out)
args : (Tensor x, Tensor out_grad, IntArray offsets)
output : Tensor(x_grad)
infer_meta :
func : CropTensorGradInferMeta
func : CropGradInferMeta
kernel :
func : crop_tensor_grad
func : crop_grad
data_type : x
- backward_op : cross_entropy_with_softmax_grad
......@@ -779,30 +768,6 @@
kernel :
func : gelu_grad
- backward_op : graph_send_recv_grad
forward : graph_send_recv (Tensor x, Tensor src_index, Tensor dst_index, str reduce_op = "SUM", IntArray out_size = {0}) -> Tensor(out), Tensor(dst_count)
args : (Tensor x, Tensor src_index, Tensor dst_index, Tensor out, Tensor dst_count, Tensor out_grad, str reduce_op = "SUM")
output : Tensor(x_grad)
infer_meta :
func : GeneralUnaryGradInferMeta
param : [x]
kernel :
func : graph_send_recv_grad
data_type : out_grad
optional: out, dst_count
- backward_op : graph_send_ue_recv_grad
forward : graph_send_ue_recv (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, str message_op, str reduce_op, IntArray out_size) -> Tensor(out), Tensor(dst_count)
args : (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, Tensor out, Tensor dst_count, Tensor out_grad, str message_op, str reduce_op)
output : Tensor(x_grad), Tensor(y_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [x, y]
kernel :
func : graph_send_ue_recv_grad
data_type : out_grad
optional: out, dst_count
- backward_op : grid_sample_grad
forward : grid_sample (Tensor x, Tensor grid, str mode, str padding_mode, bool align_corners) -> Tensor(out)
args : (Tensor x, Tensor grid, Tensor out_grad, str mode, str padding_mode, bool align_corners)
......@@ -870,6 +835,17 @@
func : hard_swish_grad
inplace : (out_grad -> x_grad)
- backward_op : hardtanh_grad
forward : hardtanh (Tensor x, float t_min, float t_max) -> Tensor(out)
args : (Tensor x, Tensor out_grad, float t_min, float t_max)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : hard_tanh_grad
inplace : (out_grad -> x_grad)
- backward_op : hierarchical_sigmoid_grad
forward : hierarchical_sigmoid (Tensor x, Tensor w, Tensor label, Tensor path, Tensor code, Tensor bias, int num_classes, bool remote_prefetch, int trainer_id, int64_t[] height_sections, str[] epmap, str[] table_names, bool is_sparse) -> Tensor(out), Tensor(pre_out), Tensor(w_out)
args : (Tensor x, Tensor w, Tensor label, Tensor path, Tensor code, Tensor bias, Tensor pre_out, Tensor out_grad, int num_classes, bool remote_prefetch, int trainer_id, int64_t[] height_sections, str[] epmap, str[] table_names, bool is_sparse)
......@@ -1624,12 +1600,12 @@
# output is optional
- backward_op : put_along_axis_grad
forward : put_along_axis (Tensor arr, Tensor index, Tensor value, int axis, str reduce) -> Tensor(out)
args : (Tensor arr, Tensor index, Tensor out_grad, int axis, str reduce)
forward : put_along_axis (Tensor arr, Tensor indices, Tensor value, int axis, str reduce) -> Tensor(out)
args : (Tensor arr, Tensor indices, Tensor out_grad, int axis, str reduce)
output : Tensor(arr_grad), Tensor(value_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [arr, index]
param : [arr, indices]
kernel :
func : put_along_axis_grad
......@@ -1911,6 +1887,30 @@
kernel :
func : selu_grad
- backward_op : send_u_recv_grad
forward : send_u_recv (Tensor x, Tensor src_index, Tensor dst_index, str reduce_op = "SUM", IntArray out_size = {0}) -> Tensor(out), Tensor(dst_count)
args : (Tensor x, Tensor src_index, Tensor dst_index, Tensor out, Tensor dst_count, Tensor out_grad, str reduce_op = "SUM")
output : Tensor(x_grad)
infer_meta :
func : GeneralUnaryGradInferMeta
param : [x]
kernel :
func : send_u_recv_grad
data_type : out_grad
optional: out, dst_count
- backward_op : send_ue_recv_grad
forward : send_ue_recv (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, str message_op, str reduce_op, IntArray out_size) -> Tensor(out), Tensor(dst_count)
args : (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, Tensor out, Tensor dst_count, Tensor out_grad, str message_op, str reduce_op)
output : Tensor(x_grad), Tensor(y_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [x, y]
kernel :
func : send_ue_recv_grad
data_type : out_grad
optional: out, dst_count
- backward_op : sigmoid_cross_entropy_with_logits_grad
forward : sigmoid_cross_entropy_with_logits (Tensor x, Tensor label, bool normalize, int ignore_index) -> Tensor(out)
args : (Tensor x, Tensor label, Tensor out_grad, bool normalize, int ignore_index)
......
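Two notes on the hunks above. For addmm, only the attribute order in the forward spec changes, from (alpha, beta) to (beta, alpha) — presumably to match the Python signature paddle.addmm(input, x, y, beta=1.0, alpha=1.0), which computes out = beta * input + alpha * (x @ y); the backward args keep their original order. For hardtanh_grad, the op name uses the hardtanh spelling while the kernel func is hard_tanh, mirroring the forward entry added in ops.yaml.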
......@@ -88,7 +88,7 @@
backward : add_n_grad
- op : addmm
args : (Tensor input, Tensor x, Tensor y, float alpha, float beta)
args : (Tensor input, Tensor x, Tensor y, float beta, float alpha)
output : Tensor
infer_meta :
func : AddmmInferMeta
......@@ -346,16 +346,6 @@
func : box_coder
optional : prior_box_var
- op : brelu
args : (Tensor x, float t_min, float t_max)
output : Tensor
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : brelu
backward : brelu_grad
- op : cast
args : (Tensor x, DataType dtype)
output : Tensor
......@@ -508,15 +498,15 @@
output : Tensor(out)
invoke : copy_to_impl(x, place, blocking)
- op : crop_tensor
- op : crop
args : (Tensor x, IntArray shape, IntArray offsets)
output : Tensor(out)
infer_meta :
func : CropTensorInferMeta
func : CropInferMeta
kernel :
func : crop_tensor
func : crop
data_type : x
backward : crop_tensor_grad
backward : crop_grad
# Part of python API paddle.nn.functional.cross_entropy
- op : cross_entropy_with_softmax
......@@ -979,14 +969,14 @@
kernel :
func : gather_tree
- op : gaussian_random
- op : gaussian
args : (IntArray shape, float mean, float std, int seed, DataType dtype, Place place={})
output: Tensor(out)
infer_meta :
func : GaussianRandomInferMeta
func : GaussianInferMeta
param : [shape, mean, std, seed, dtype]
kernel :
func : gaussian_random
func : gaussian
param : [shape, mean, std, seed, dtype]
data_type : dtype
backend : place
......@@ -1009,28 +999,6 @@
kernel :
func : generate_proposals_v2
- op : graph_send_recv
args : (Tensor x, Tensor src_index, Tensor dst_index, str reduce_op = "SUM", IntArray out_size = {0})
output : Tensor(out), Tensor(dst_count)
infer_meta :
func : GraphSendRecvInferMeta
kernel :
func : graph_send_recv
data_type : x
intermediate : dst_count
backward : graph_send_recv_grad
- op : graph_send_ue_recv
args : (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, str message_op, str reduce_op, IntArray out_size)
output : Tensor(out), Tensor(dst_count)
infer_meta :
func : GraphSendUERecvInferMeta
kernel :
func : graph_send_ue_recv
data_type : x
intermediate : dst_count
backward : graph_send_ue_recv_grad
- op : greater_equal
args : (Tensor x, Tensor y, int axis = -1)
output : Tensor(out)
......@@ -1108,6 +1076,16 @@
func : hard_swish
backward : hardswish_grad
- op : hardtanh
args : (Tensor x, float t_min, float t_max)
output : Tensor
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : hard_tanh
backward : hardtanh_grad
- op : hierarchical_sigmoid
args : (Tensor x, Tensor w, Tensor label, Tensor path, Tensor code, Tensor bias, int num_classes, bool remote_prefetch, int trainer_id, int64_t[] height_sections, str[] epmap, str[] table_names, bool is_sparse)
output : Tensor(out), Tensor(pre_out), Tensor(w_out)
......@@ -1958,7 +1936,7 @@
backward : psroi_pool_grad
- op : put_along_axis
args : (Tensor arr, Tensor index, Tensor value, int axis, str reduce)
args : (Tensor arr, Tensor indices, Tensor values, int axis, str reduce)
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
......@@ -2234,6 +2212,28 @@
func : selu
backward : selu_grad
- op : send_u_recv
args : (Tensor x, Tensor src_index, Tensor dst_index, str reduce_op = "SUM", IntArray out_size = {0})
output : Tensor(out), Tensor(dst_count)
infer_meta :
func : SendURecvInferMeta
kernel :
func : send_u_recv
data_type : x
intermediate : dst_count
backward : send_u_recv_grad
- op : send_ue_recv
args : (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, str message_op, str reduce_op, IntArray out_size)
output : Tensor(out), Tensor(dst_count)
infer_meta :
func : SendUERecvInferMeta
kernel :
func : send_ue_recv
data_type : x
intermediate : dst_count
backward : send_ue_recv_grad
- op : sgd_
args : (Tensor param, Tensor learning_rate, Tensor grad, Tensor master_param, bool multi_precision)
output : Tensor(param_out), Tensor(master_param_out)
......
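These YAML entries are the source from which PHI's C++ API and operator registrations are generated, so renaming an op, its infer_meta func, and its kernel func here is what actually renames the generated C++ interface. The legacy operator names stay reachable through the compat mapping in the next hunk (e.g. "- op : send_uv (graph_send_uv)").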
......@@ -704,6 +704,9 @@
extra :
attrs : [bool deterministic = false, str rng_name = "", bool force_cpu = false]
- op : send_uv (graph_send_uv)
backward : send_uv_grad (graph_send_uv_grad)
- op : sequence_softmax
backward : sequence_softmax_grad
extra :
......
......@@ -262,16 +262,6 @@
func : flip
backward : flip_grad
- op : graph_send_uv
args : (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, str message_op = "ADD")
output : Tensor(out)
infer_meta :
func : GraphSendUVInferMeta
kernel :
func : graph_send_uv
data_type : x
backward : graph_send_uv_grad
- op : lgamma
args : (Tensor x)
output : Tensor(out)
......@@ -299,6 +289,16 @@
func : poisson
backward : poisson_grad
- op : send_uv
args : (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, str message_op = "ADD")
output : Tensor(out)
infer_meta :
func : SendUVInferMeta
kernel :
func : send_uv
data_type : x
backward : send_uv_grad
- op : sin
args : (Tensor x)
output : Tensor
......
......@@ -44,7 +44,7 @@
add_coo_dense_grad{sparse_coo, dense, sparse_coo -> sparse_coo, dense}
- backward_op : addmm_grad
forward : addmm(Tensor input, Tensor x, Tensor y, float alpha=1.0, float beta=1.0) -> Tensor(out)
forward : addmm(Tensor input, Tensor x, Tensor y, float beta=1.0, float alpha=1.0) -> Tensor(out)
args : (Tensor input, Tensor x, Tensor y, Tensor out_grad, float alpha=1.0, float beta=1.0)
output : Tensor(input_grad), Tensor(x_grad), Tensor(y_grad)
infer_meta :
......
......@@ -224,6 +224,17 @@
layout : x
backward : relu6_grad
- op : reshape
args : (Tensor x, IntArray shape)
output : Tensor(out)
infer_meta :
func : ReshapeInferMeta
kernel :
func : reshape_coo{sparse_coo -> sparse_coo},
reshape_csr{sparse_csr -> sparse_csr}
layout : x
backward : reshape_grad
- op : scale
args : (Tensor x, float scale, float bias, bool bias_after_scale)
output : Tensor(out)
......@@ -312,6 +323,17 @@
layout : x
backward : subtract_grad
- op : sync_batch_norm_
args : (Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu)
output : Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space)
infer_meta :
func : BatchNormInferMeta
kernel :
func : sync_batch_norm_coo{sparse_coo, dense, dense, dense, dense -> sparse_coo, dense, dense, dense, dense, dense}
data_type : x
backward : sync_batch_norm_grad
inplace : (mean -> mean_out), (variance -> variance_out)
- op : tan
args : (Tensor x)
output : Tensor(out)
......@@ -364,6 +386,18 @@
func : dense_to_csr {dense -> sparse_csr},
coo_to_csr {sparse_coo -> sparse_csr}
- op : transpose
args : (Tensor x, int[] perm)
output : Tensor(out)
infer_meta :
func : TransposeInferMeta
param: [ x, perm ]
kernel :
func : transpose_coo{sparse_coo -> sparse_coo},
transpose_csr{sparse_csr -> sparse_csr}
layout : x
backward : transpose_grad
- op : values
args : (Tensor x)
output : Tensor(out)
......@@ -376,7 +410,7 @@
backward : values_grad
- op: addmm
args : (Tensor input, Tensor x, Tensor y, float alpha=1.0, float beta=1.0)
args : (Tensor input, Tensor x, Tensor y, float beta=1.0, float alpha=1.0)
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
......@@ -469,37 +503,3 @@
mv_csr{sparse_csr, dense -> dense}
layout : x
backward: mv_grad
- op : transpose
args : (Tensor x, int[] perm)
output : Tensor(out)
infer_meta :
func : TransposeInferMeta
param: [ x, perm ]
kernel :
func : transpose_coo{sparse_coo -> sparse_coo},
transpose_csr{sparse_csr -> sparse_csr}
layout : x
backward : transpose_grad
- op : sync_batch_norm_
args : (Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu)
output : Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space)
infer_meta :
func : BatchNormInferMeta
kernel :
func : sync_batch_norm_coo{sparse_coo, dense, dense, dense, dense -> sparse_coo, dense, dense, dense, dense, dense}
data_type : x
backward : sync_batch_norm_grad
inplace : (mean -> mean_out), (variance -> variance_out)
- op : reshape
args : (Tensor x, IntArray shape)
output : Tensor(out)
infer_meta :
func : ReshapeInferMeta
kernel :
func : reshape_coo{sparse_coo -> sparse_coo},
reshape_csr{sparse_csr -> sparse_csr}
layout : x
backward : reshape_grad
......@@ -81,7 +81,9 @@ static const std::unordered_set<std::string> deprecated_op_names(
"nearest_interp",
"nearest_interp_grad",
"bicubic_interp",
"bicubic_interp_grad"});
"bicubic_interp_grad",
"crop",
"crop_grad"});
class DefaultKernelSignatureMap {
public:
......
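crop and crop_grad join the deprecated-name set above, presumably so the new phi kernel name "crop" is not resolved as the legacy fluid op of the same name. A hedged sketch of how such a set is typically consulted (hypothetical helper — the actual lookup site is outside this hunk):

// Hypothetical call site, not part of this diff: reject legacy names
// when mapping fluid ops to phi kernel signatures.
bool IsDeprecatedOpName(const std::string& name) {
  return deprecated_op_names.count(name) > 0;
}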
......@@ -186,10 +186,10 @@ void Conv2dTransposeDoubleGradInferMeta(const MetaTensor& x,
}
}
void CropTensorGradInferMeta(const MetaTensor& out_grad,
const MetaTensor& x,
const IntArray& offsets,
MetaTensor* x_grad) {
void CropGradInferMeta(const MetaTensor& out_grad,
const MetaTensor& x,
const IntArray& offsets,
MetaTensor* x_grad) {
auto x_dims = x.dims();
if (x_grad != nullptr) {
......
......@@ -107,10 +107,10 @@ void Conv2dTransposeDoubleGradInferMeta(const MetaTensor& x,
MetaTensor* dfilter,
MetaTensor* ddout);
void CropTensorGradInferMeta(const MetaTensor& out_grad,
const MetaTensor& x,
const IntArray& offsets,
MetaTensor* x_grad);
void CropGradInferMeta(const MetaTensor& out_grad,
const MetaTensor& x,
const IntArray& offsets,
MetaTensor* x_grad);
void CrossEntropyWithSoftmaxGradInferMeta(const MetaTensor& label,
const MetaTensor& softmax,
......
......@@ -2455,6 +2455,164 @@ void SgdInferMeta(const MetaTensor& param,
param_out->set_dtype(param.dtype());
}
void SendUERecvInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& src_index,
const MetaTensor& dst_index,
const std::string& message_op,
const std::string& reduce_op,
const IntArray& out_size,
MetaTensor* out,
MetaTensor* dst_count) {
auto src_index_dims = src_index.dims();
if (src_index_dims.size() == 2) {
PADDLE_ENFORCE_EQ(src_index_dims[1],
1,
phi::errors::InvalidArgument(
"The last dim of Src_index should be 1 when it "
"is 2D, but we get %d",
src_index_dims[1]));
} else {
PADDLE_ENFORCE_EQ(
src_index_dims.size(),
1,
phi::errors::InvalidArgument(
"The Src_index should be 1D, when it is not 2D, but we get %d",
src_index_dims.size()));
}
auto dst_index_dims = dst_index.dims();
if (dst_index_dims.size() == 2) {
PADDLE_ENFORCE_EQ(dst_index_dims[1],
1,
phi::errors::InvalidArgument(
"The last dim of Dst_index should be 1 when it "
"is 2D, but we get %d",
dst_index_dims[1]));
} else {
PADDLE_ENFORCE_EQ(
dst_index_dims.size(),
1,
phi::errors::InvalidArgument("The Dst_index should be 1D, "
"when it is not 2D, but we get %d",
dst_index_dims.size()));
}
PADDLE_ENFORCE_EQ(src_index_dims[0],
dst_index_dims[0],
phi::errors::InvalidArgument(
"Src_index and Dst_index should have the same shape."));
auto y_dims = y.dims();
PADDLE_ENFORCE_EQ(
y_dims[0],
src_index_dims[0],
phi::errors::InvalidArgument(
"Expect Input Y to have size %d as Src_index on the first dimension, "
"but we get %d",
src_index_dims[0],
y_dims[0]));
auto x_dims = x.dims();
if (reduce_op == "MEAN") {
dst_count->set_dims({-1});
dst_count->set_dtype(DataType::INT32);
}
// Infer out's shape according to x and e(need broadcasting condition)
out->set_dtype(x.dtype());
auto x_dims1 = phi::vectorize<int>(x_dims);
auto y_dims1 = phi::vectorize<int>(y_dims);
std::vector<int> x_dims2(x_dims1.begin() + 1, x_dims1.end());
std::vector<int> y_dims2(y_dims1.begin() + 1, y_dims1.end());
int max_dim = std::max(x_dims2.size(), y_dims2.size());
int axis = std::abs(static_cast<int>(x_dims2.size() - y_dims2.size()));
std::vector<int> x_dims_array(max_dim);
std::vector<int> y_dims_array(max_dim);
std::vector<int> out_dims_array(max_dim);
// Only need to broadcast dimensions other than the 0th dimension.
phi::funcs::GetBroadcastDimsArrays(phi::make_ddim(x_dims2),
phi::make_ddim(y_dims2),
x_dims_array.data(),
y_dims_array.data(),
out_dims_array.data(),
max_dim,
axis);
out_dims_array.insert(out_dims_array.begin(), -1);
out->set_dims(phi::make_ddim(out_dims_array));
}
void SendUVInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& src_index,
const MetaTensor& dst_index,
const std::string& message_op,
MetaTensor* out) {
auto src_index_dims = src_index.dims();
if (src_index_dims.size() == 2) {
PADDLE_ENFORCE_EQ(src_index_dims[1],
1,
phi::errors::InvalidArgument(
"The last dim of Src_index should be 1 when it "
"is 2D, but we get %d",
src_index_dims[1]));
} else {
PADDLE_ENFORCE_EQ(
src_index_dims.size(),
1,
phi::errors::InvalidArgument(
"The Src_index should be 1D, when it is not 2D, but we get %d",
src_index_dims.size()));
}
auto dst_index_dims = dst_index.dims();
if (dst_index_dims.size() == 2) {
PADDLE_ENFORCE_EQ(dst_index_dims[1],
1,
phi::errors::InvalidArgument(
"The last dim of Dst_index should be 1 when it "
"is 2D, but we get %d",
dst_index_dims[1]));
} else {
PADDLE_ENFORCE_EQ(
dst_index_dims.size(),
1,
phi::errors::InvalidArgument("The Dst_index should be 1D, "
"when it is not 2D, but we get %d",
dst_index_dims.size()));
}
PADDLE_ENFORCE_EQ(src_index_dims[0],
dst_index_dims[0],
phi::errors::InvalidArgument(
"Src_index and Dst_index should have the same shape."));
// Infer out's shape according to x and y(need broadcasting condition)
out->set_dtype(x.dtype());
auto x_dims = x.dims();
auto y_dims = y.dims();
auto x_dims1 = phi::vectorize<int>(x_dims);
auto y_dims1 = phi::vectorize<int>(y_dims);
std::vector<int> x_dims2(x_dims1.begin() + 1, x_dims1.end());
std::vector<int> y_dims2(y_dims1.begin() + 1, y_dims1.end());
int max_dim = std::max(x_dims2.size(), y_dims2.size());
int axis = std::abs(static_cast<int>(x_dims2.size() - y_dims2.size()));
std::vector<int> x_dims_array(max_dim);
std::vector<int> y_dims_array(max_dim);
std::vector<int> out_dims_array(max_dim);
// Only need to broadcast dimensions other than the 0th dimension.
phi::funcs::GetBroadcastDimsArrays(phi::make_ddim(x_dims2),
phi::make_ddim(y_dims2),
x_dims_array.data(),
y_dims_array.data(),
out_dims_array.data(),
max_dim,
axis);
out_dims_array.insert(out_dims_array.begin(), src_index_dims[0]);
out->set_dims(phi::make_ddim(out_dims_array));
}
void StackInferMeta(const std::vector<const MetaTensor*>& x,
int axis,
MetaTensor* out,
......@@ -2751,164 +2909,6 @@ void Yolov3LossInferMeta(const MetaTensor& x,
gt_match_mask->set_dtype(x.dtype());
}
void GraphSendUERecvInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& src_index,
const MetaTensor& dst_index,
const std::string& message_op,
const std::string& reduce_op,
const IntArray& out_size,
MetaTensor* out,
MetaTensor* dst_count) {
auto src_index_dims = src_index.dims();
if (src_index_dims.size() == 2) {
PADDLE_ENFORCE_EQ(src_index_dims[1],
1,
phi::errors::InvalidArgument(
"The last dim of Src_index should be 1 when it "
"is 2D, but we get %d",
src_index_dims[1]));
} else {
PADDLE_ENFORCE_EQ(
src_index_dims.size(),
1,
phi::errors::InvalidArgument(
"The Src_index should be 1D, when it is not 2D, but we get %d",
src_index_dims.size()));
}
auto dst_index_dims = dst_index.dims();
if (dst_index_dims.size() == 2) {
PADDLE_ENFORCE_EQ(dst_index_dims[1],
1,
phi::errors::InvalidArgument(
"The last dim of Dst_index should be 1 when it "
"is 2D, but we get %d",
dst_index_dims[1]));
} else {
PADDLE_ENFORCE_EQ(
dst_index_dims.size(),
1,
phi::errors::InvalidArgument("The Dst_index should be 1D, "
"when it is not 2D, but we get %d",
dst_index_dims.size()));
}
PADDLE_ENFORCE_EQ(src_index_dims[0],
dst_index_dims[0],
phi::errors::InvalidArgument(
"Src_index and Dst_index should have the same shape."));
auto y_dims = y.dims();
PADDLE_ENFORCE_EQ(
y_dims[0],
src_index_dims[0],
phi::errors::InvalidArgument(
"Expect Input Y to have size %d as Src_index on the first dimension, "
"but we get %d",
src_index_dims[0],
y_dims[0]));
auto x_dims = x.dims();
if (reduce_op == "MEAN") {
dst_count->set_dims({-1});
dst_count->set_dtype(DataType::INT32);
}
// Infer out's shape according to x and e(need broadcasting condition)
out->set_dtype(x.dtype());
auto x_dims1 = phi::vectorize<int>(x_dims);
auto y_dims1 = phi::vectorize<int>(y_dims);
std::vector<int> x_dims2(x_dims1.begin() + 1, x_dims1.end());
std::vector<int> y_dims2(y_dims1.begin() + 1, y_dims1.end());
int max_dim = std::max(x_dims2.size(), y_dims2.size());
int axis = std::abs(static_cast<int>(x_dims2.size() - y_dims2.size()));
std::vector<int> x_dims_array(max_dim);
std::vector<int> y_dims_array(max_dim);
std::vector<int> out_dims_array(max_dim);
// Only need to broadcast dimensions other than the 0th dimension.
phi::funcs::GetBroadcastDimsArrays(phi::make_ddim(x_dims2),
phi::make_ddim(y_dims2),
x_dims_array.data(),
y_dims_array.data(),
out_dims_array.data(),
max_dim,
axis);
out_dims_array.insert(out_dims_array.begin(), -1);
out->set_dims(phi::make_ddim(out_dims_array));
}
void GraphSendUVInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& src_index,
const MetaTensor& dst_index,
const std::string& message_op,
MetaTensor* out) {
auto src_index_dims = src_index.dims();
if (src_index_dims.size() == 2) {
PADDLE_ENFORCE_EQ(src_index_dims[1],
1,
phi::errors::InvalidArgument(
"The last dim of Src_index should be 1 when it "
"is 2D, but we get %d",
src_index_dims[1]));
} else {
PADDLE_ENFORCE_EQ(
src_index_dims.size(),
1,
phi::errors::InvalidArgument(
"The Src_index should be 1D, when it is not 2D, but we get %d",
src_index_dims.size()));
}
auto dst_index_dims = dst_index.dims();
if (dst_index_dims.size() == 2) {
PADDLE_ENFORCE_EQ(dst_index_dims[1],
1,
phi::errors::InvalidArgument(
"The last dim of Dst_index should be 1 when it "
"is 2D, but we get %d",
dst_index_dims[1]));
} else {
PADDLE_ENFORCE_EQ(
dst_index_dims.size(),
1,
phi::errors::InvalidArgument("The Dst_index should be 1D, "
"when it is not 2D, but we get %d",
dst_index_dims.size()));
}
PADDLE_ENFORCE_EQ(src_index_dims[0],
dst_index_dims[0],
phi::errors::InvalidArgument(
"Src_index and Dst_index should have the same shape."));
// Infer out's shape according to x and y(need broadcasting condition)
out->set_dtype(x.dtype());
auto x_dims = x.dims();
auto y_dims = y.dims();
auto x_dims1 = phi::vectorize<int>(x_dims);
auto y_dims1 = phi::vectorize<int>(y_dims);
std::vector<int> x_dims2(x_dims1.begin() + 1, x_dims1.end());
std::vector<int> y_dims2(y_dims1.begin() + 1, y_dims1.end());
int max_dim = std::max(x_dims2.size(), y_dims2.size());
int axis = std::abs(static_cast<int>(x_dims2.size() - y_dims2.size()));
std::vector<int> x_dims_array(max_dim);
std::vector<int> y_dims_array(max_dim);
std::vector<int> out_dims_array(max_dim);
// Only need to broadcast dimensions other than the 0th dimension.
phi::funcs::GetBroadcastDimsArrays(phi::make_ddim(x_dims2),
phi::make_ddim(y_dims2),
x_dims_array.data(),
y_dims_array.data(),
out_dims_array.data(),
max_dim,
axis);
out_dims_array.insert(out_dims_array.begin(), src_index_dims[0]);
out->set_dims(phi::make_ddim(out_dims_array));
}
} // namespace phi
PD_REGISTER_INFER_META_FN(batch_norm_infer, phi::BatchNormInferInferMeta);
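A worked example of the shape inference above, with hypothetical shapes: for send_ue_recv, take x of shape [N, 8, 1] and y of shape [E, 1, 16], where E equals the number of rows in src_index. The leading dimensions are dropped, the trailing dims [8, 1] and [1, 16] broadcast to [8, 16], and out is set to [-1, 8, 16] — the leading dimension is left unknown until runtime. send_uv applies the same trailing-dim broadcast but knows its leading output dimension statically, inserting src_index_dims[0] to give [E, 8, 16].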
......@@ -451,6 +451,23 @@ void RnnInferMeta(const MetaTensor& x,
std::vector<MetaTensor*> state,
MetaTensor* reserve);
void SendUERecvInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& src_index,
const MetaTensor& dst_index,
const std::string& message_op,
const std::string& reduce_op,
const IntArray& out_size,
MetaTensor* out,
MetaTensor* dst_count);
void SendUVInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& src_index,
const MetaTensor& dst_index,
const std::string& message_op,
MetaTensor* out);
void SgdInferMeta(const MetaTensor& param,
const MetaTensor& learning_rate,
const MetaTensor& grad,
......@@ -506,21 +523,4 @@ void Yolov3LossInferMeta(const MetaTensor& x,
MetaTensor* objectness_mask,
MetaTensor* gt_match_mask);
void GraphSendUERecvInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& src_index,
const MetaTensor& dst_index,
const std::string& message_op,
const std::string& reduce_op,
const IntArray& out_size,
MetaTensor* out,
MetaTensor* dst_count);
void GraphSendUVInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& src_index,
const MetaTensor& dst_index,
const std::string& message_op,
MetaTensor* out);
} // namespace phi
......@@ -73,12 +73,12 @@ void EyeInferMeta(const Scalar& num_rows,
out->set_dtype(dtype);
}
void GaussianRandomInferMeta(const IntArray& shape,
float mean,
float std,
int seed,
DataType dtype,
MetaTensor* out) {
void GaussianInferMeta(const IntArray& shape,
float mean,
float std,
int seed,
DataType dtype,
MetaTensor* out) {
auto out_dims = phi::make_ddim(shape.GetData());
out->set_dims(out_dims);
out->set_dtype(dtype);
......
......@@ -48,12 +48,12 @@ void EyeInferMeta(const Scalar& num_rows,
MetaTensor* out,
MetaConfig config = MetaConfig());
void GaussianRandomInferMeta(const IntArray& shape,
float mean,
float std,
int seed,
DataType dtype,
MetaTensor* out);
void GaussianInferMeta(const IntArray& shape,
float mean,
float std,
int seed,
DataType dtype,
MetaTensor* out);
void RandpermInferMeta(int n, DataType dtype, MetaTensor* out);
......
......@@ -78,8 +78,8 @@ void AccuracyInferMeta(const MetaTensor& out,
void AddmmInferMeta(const MetaTensor& input,
const MetaTensor& x,
const MetaTensor& y,
float alpha,
float beta,
float alpha,
MetaTensor* out) {
auto input_dims = input.dims();
auto x_dims = x.dims();
......@@ -402,13 +402,13 @@ void InstanceNormInferMeta(const MetaTensor& x,
}
}
void GraphSendRecvInferMeta(const MetaTensor& x,
const MetaTensor& src_index,
const MetaTensor& dst_index,
const std::string& reduce_op,
const IntArray& out_size,
MetaTensor* out,
MetaTensor* dst_count) {
void SendURecvInferMeta(const MetaTensor& x,
const MetaTensor& src_index,
const MetaTensor& dst_index,
const std::string& reduce_op,
const IntArray& out_size,
MetaTensor* out,
MetaTensor* dst_count) {
auto src_index_dims = src_index.dims();
if (src_index_dims.size() == 2) {
PADDLE_ENFORCE_EQ(src_index_dims[1],
......
......@@ -44,8 +44,8 @@ void AccuracyInferMeta(const MetaTensor& out,
void AddmmInferMeta(const MetaTensor& input,
const MetaTensor& x,
const MetaTensor& y,
float alpha,
float beta,
float alpha,
MetaTensor* out);
void ArangeInferMeta(const MetaTensor& start,
......@@ -72,13 +72,13 @@ void InstanceNormInferMeta(const MetaTensor& x,
MetaTensor* saved_variance,
MetaConfig config = MetaConfig());
void GraphSendRecvInferMeta(const MetaTensor& x,
const MetaTensor& src_index,
const MetaTensor& dst_index,
const std::string& reduce_op,
const IntArray& out_size,
MetaTensor* out,
MetaTensor* dst_count);
void SendURecvInferMeta(const MetaTensor& x,
const MetaTensor& src_index,
const MetaTensor& dst_index,
const std::string& reduce_op,
const IntArray& out_size,
MetaTensor* out,
MetaTensor* dst_count);
void GroupNormInferMeta(const MetaTensor& x,
const MetaTensor& scale,
......
......@@ -436,11 +436,11 @@ void CumScalarAxisInferMeta(const MetaTensor& x,
CumInferMeta(x, axis.to<int>(), flatten, exclusive, reverse, out);
}
void CropTensorInferMeta(const MetaTensor& x,
const IntArray& shape,
const IntArray& offsets,
MetaTensor* out,
MetaConfig config) {
void CropInferMeta(const MetaTensor& x,
const IntArray& shape,
const IntArray& offsets,
MetaTensor* out,
MetaConfig config) {
PADDLE_ENFORCE_NE(
out,
nullptr,
......
......@@ -82,11 +82,11 @@ void ClipByNormInferMeta(const MetaTensor& x, float max_norm, MetaTensor* out);
void CreateLikeInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out);
void CropTensorInferMeta(const MetaTensor& x,
const IntArray& shape,
const IntArray& offsets,
MetaTensor* out,
MetaConfig config = MetaConfig());
void CropInferMeta(const MetaTensor& x,
const IntArray& shape,
const IntArray& offsets,
MetaTensor* out,
MetaConfig config = MetaConfig());
void CumInferMeta(const MetaTensor& x,
int axis,
......
......@@ -243,7 +243,7 @@ DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Mish, threshold);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Celu, alpha);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPOUT(Relu6, threshold);
DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(BRelu, t_min, t_max);
DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(HardTanh, t_min, t_max);
DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(STanh, scale_a, scale_b);
DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(Softplus, beta, threshold);
DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPOUT(HardSigmoid, slope, offset);
......
......@@ -85,7 +85,7 @@ DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Swish, beta)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Celu, alpha)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Logit, eps)
DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(BRelu, t_min, t_max)
DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(HardTanh, t_min, t_max)
DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(STanh, scale_a, scale_b)
DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(Softplus, beta, threshold)
DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(HardSigmoid, slope, offset)
......
......@@ -23,8 +23,8 @@ void AddmmKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& x,
const DenseTensor& y,
float alpha,
float beta,
float alpha,
DenseTensor* out);
} // namespace phi
......@@ -173,8 +173,8 @@ DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Mish,
threshold);
DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Celu, CELUGradFunctor, alpha);
DEFINE_CPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(BRelu,
BReluGradFunctor,
DEFINE_CPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(HardTanh,
HardTanhGradFunctor,
t_min,
t_max);
......@@ -263,7 +263,7 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(asinh_grad, AsinhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(acosh_grad, AcoshGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(atanh_grad, AtanhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(tanh_grad, TanhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(brelu_grad, BReluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_tanh_grad, HardTanhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_grad, LeakyReluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(thresholded_relu_grad,
ThresholdedReluGradKernel)
......
......@@ -104,7 +104,7 @@ DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Elu, ELUFunctor, alpha)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Swish, SwishFunctor, beta)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Celu, CELUFunctor, alpha)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(BRelu, BReluFunctor, t_min, t_max)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(HardTanh, HardTanhFunctor, t_min, t_max)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(STanh, STanhFunctor, scale_a, scale_b)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(Softplus, SoftplusFunctor, beta, threshold)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(HardSigmoid,
......@@ -146,7 +146,7 @@ PD_REGISTER_ACTIVATION_KERNEL(asinh, AsinhKernel)
PD_REGISTER_ACTIVATION_KERNEL(acosh, AcoshKernel)
PD_REGISTER_ACTIVATION_KERNEL(atanh, AtanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(tanh, TanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(brelu, BReluKernel)
PD_REGISTER_ACTIVATION_KERNEL(hard_tanh, HardTanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
PD_REGISTER_ACTIVATION_KERNEL(thresholded_relu, ThresholdedReluKernel)
PD_REGISTER_ACTIVATION_KERNEL(relu6, Relu6Kernel)
......
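The DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS macro above stamps out a kernel that routes its two float attributes into the functor's attribute map. A simplified sketch of the shape of the generated code — an assumption, since the macro body is not part of this diff:

template <typename T, typename Context>
void HardTanhKernel(const Context& dev_ctx,
                    const DenseTensor& x,
                    float t_min,
                    float t_max,
                    DenseTensor* out) {
  funcs::HardTanhFunctor<T> functor;
  auto attrs = functor.GetAttrs();  // pairs of attr name -> float* member
  *(attrs[0].second) = t_min;
  *(attrs[1].second) = t_max;
  // ...then run the shared elementwise activation launcher with `functor`.
}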
......@@ -12,16 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/crop_tensor_kernel.h"
#include "paddle/phi/kernels/crop_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/crop_tensor_kernel_impl.h"
#include "paddle/phi/kernels/impl/crop_grad_kernel_impl.h"
PD_REGISTER_KERNEL(crop_tensor,
PD_REGISTER_KERNEL(crop_grad,
CPU,
ALL_LAYOUT,
phi::CropTensorKernel,
phi::CropGradKernel,
float,
double,
int,
......
......@@ -12,17 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/crop_tensor_grad_kernel.h"
#include "paddle/phi/kernels/crop_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/crop_tensor_grad_kernel_impl.h"
#include "paddle/phi/kernels/impl/crop_kernel_impl.h"
PD_REGISTER_KERNEL(crop_tensor_grad,
CPU,
ALL_LAYOUT,
phi::CropTensorGradKernel,
float,
double,
int,
int64_t) {}
PD_REGISTER_KERNEL(
crop, CPU, ALL_LAYOUT, phi::CropKernel, float, double, int, int64_t) {}
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/gaussian_random_kernel.h"
#include "paddle/phi/kernels/gaussian_kernel.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
......@@ -21,13 +21,13 @@
namespace phi {
template <typename T, typename Context>
void GaussianRandomKernel(const Context& dev_ctx,
const IntArray& shape,
float mean,
float std,
int seed,
DataType dtype,
DenseTensor* out) {
void GaussianKernel(const Context& dev_ctx,
const IntArray& shape,
float mean,
float std,
int seed,
DataType dtype,
DenseTensor* out) {
auto tensor = out;
std::normal_distribution<T> dist(mean, std);
......@@ -44,9 +44,5 @@ void GaussianRandomKernel(const Context& dev_ctx,
} // namespace phi
PD_REGISTER_KERNEL(gaussian_random,
CPU,
ALL_LAYOUT,
phi::GaussianRandomKernel,
float,
double) {}
PD_REGISTER_KERNEL(
gaussian, CPU, ALL_LAYOUT, phi::GaussianKernel, float, double) {}
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/graph_send_recv_grad_kernel.h"
#include "paddle/phi/kernels/send_u_recv_grad_kernel.h"
#include <algorithm>
#include <vector>
......@@ -117,15 +117,15 @@ void GraphSendRecvGradOpKernelLaunchHelper(
}
template <typename T, typename Context>
void GraphSendRecvGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const paddle::optional<DenseTensor>& out,
const paddle::optional<DenseTensor>& dst_count,
const DenseTensor& out_grad,
const std::string& reduce_op,
DenseTensor* x_grad) {
void SendURecvGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const paddle::optional<DenseTensor>& out,
const paddle::optional<DenseTensor>& dst_count,
const DenseTensor& out_grad,
const std::string& reduce_op,
DenseTensor* x_grad) {
auto index_type = src_index.dtype();
if (index_type == phi::DataType::INT32) {
GraphSendRecvGradOpKernelLaunchHelper<Context, T, int32_t>(
......@@ -154,10 +154,10 @@ void GraphSendRecvGradKernel(const Context& ctx,
} // namespace phi
PD_REGISTER_KERNEL(graph_send_recv_grad,
PD_REGISTER_KERNEL(send_u_recv_grad,
CPU,
ALL_LAYOUT,
phi::GraphSendRecvGradKernel,
phi::SendURecvGradKernel,
float,
double,
int,
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/graph_send_recv_kernel.h"
#include "paddle/phi/kernels/send_u_recv_kernel.h"
#include <algorithm>
#include <set>
......@@ -144,14 +144,14 @@ void GraphSendRecvOpKernelLaunchHelper(const Context& ctx,
}
template <typename T, typename Context>
void GraphSendRecvKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& reduce_op,
const IntArray& out_size,
DenseTensor* out,
DenseTensor* dst_count) {
void SendURecvKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& reduce_op,
const IntArray& out_size,
DenseTensor* out,
DenseTensor* dst_count) {
auto index_type = src_index.dtype();
auto& out_size_data = out_size.GetData();
if (index_type == phi::DataType::INT32) {
......@@ -177,10 +177,10 @@ void GraphSendRecvKernel(const Context& ctx,
} // namespace phi
PD_REGISTER_KERNEL(graph_send_recv,
PD_REGISTER_KERNEL(send_u_recv,
CPU,
ALL_LAYOUT,
phi::GraphSendRecvKernel,
phi::SendURecvKernel,
float,
double,
int,
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/graph_send_ue_recv_grad_kernel.h"
#include "paddle/phi/kernels/send_ue_recv_grad_kernel.h"
#include <algorithm>
#include <vector>
......@@ -443,18 +443,18 @@ void GraphSendUERecvGradOpKernelLaunchHelper(
}
template <typename T, typename Context>
void GraphSendUERecvGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const paddle::optional<DenseTensor>& out,
const paddle::optional<DenseTensor>& dst_count,
const DenseTensor& out_grad,
const std::string& message_op,
const std::string& reduce_op,
DenseTensor* x_grad,
DenseTensor* y_grad) {
void SendUERecvGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const paddle::optional<DenseTensor>& out,
const paddle::optional<DenseTensor>& dst_count,
const DenseTensor& out_grad,
const std::string& message_op,
const std::string& reduce_op,
DenseTensor* x_grad,
DenseTensor* y_grad) {
auto index_type = src_index.dtype();
if (index_type == phi::DataType::INT32) {
GraphSendUERecvGradOpKernelLaunchHelper<Context, T, int32_t>(
......@@ -489,10 +489,10 @@ void GraphSendUERecvGradKernel(const Context& ctx,
} // namespace phi
PD_REGISTER_KERNEL(graph_send_ue_recv_grad,
PD_REGISTER_KERNEL(send_ue_recv_grad,
CPU,
ALL_LAYOUT,
phi::GraphSendUERecvGradKernel,
phi::SendUERecvGradKernel,
float,
double,
int,
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/graph_send_ue_recv_kernel.h"
#include "paddle/phi/kernels/send_ue_recv_kernel.h"
#include <algorithm>
#include <set>
......@@ -244,16 +244,16 @@ void GraphSendUERecvOpKernelLaunchHelper(const Context& ctx,
}
template <typename T, typename Context>
void GraphSendUERecvKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& message_op,
const std::string& reduce_op,
const IntArray& out_size,
DenseTensor* out,
DenseTensor* dst_count) {
void SendUERecvKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& message_op,
const std::string& reduce_op,
const IntArray& out_size,
DenseTensor* out,
DenseTensor* dst_count) {
auto index_type = src_index.dtype();
auto& out_size_data = out_size.GetData();
if (index_type == phi::DataType::INT32) {
......@@ -283,10 +283,10 @@ void GraphSendUERecvKernel(const Context& ctx,
} // namespace phi
PD_REGISTER_KERNEL(graph_send_ue_recv,
PD_REGISTER_KERNEL(send_ue_recv,
CPU,
ALL_LAYOUT,
phi::GraphSendUERecvKernel,
phi::SendUERecvKernel,
float,
double,
int,
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/graph_send_uv_grad_kernel.h"
#include "paddle/phi/kernels/send_uv_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/hostdevice.h"
......@@ -229,15 +229,15 @@ void GraphSendUVGradOpKernelLaunchHelper(const Context& ctx,
}
template <typename T, typename Context>
void GraphSendUVGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const DenseTensor& out_grad,
const std::string& message_op,
DenseTensor* x_grad,
DenseTensor* y_grad) {
void SendUVGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const DenseTensor& out_grad,
const std::string& message_op,
DenseTensor* x_grad,
DenseTensor* y_grad) {
auto index_type = src_index.dtype();
if (index_type == phi::DataType::INT32) {
GraphSendUVGradOpKernelLaunchHelper<Context, T, int32_t>(
......@@ -250,10 +250,10 @@ void GraphSendUVGradKernel(const Context& ctx,
} // namespace phi
PD_REGISTER_KERNEL(graph_send_uv_grad,
PD_REGISTER_KERNEL(send_uv_grad,
CPU,
ALL_LAYOUT,
phi::GraphSendUVGradKernel,
phi::SendUVGradKernel,
float,
double,
int,
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/graph_send_uv_kernel.h"
#include "paddle/phi/kernels/send_uv_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/hostdevice.h"
......@@ -102,13 +102,13 @@ void GraphSendUVOpKernelLaunchHelper(const Context& ctx,
}
template <typename T, typename Context>
void GraphSendUVKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& message_op,
DenseTensor* out) {
void SendUVKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& message_op,
DenseTensor* out) {
auto index_type = src_index.dtype();
if (index_type == phi::DataType::INT32) {
GraphSendUVOpKernelLaunchHelper<Context, T, int32_t>(
......@@ -121,11 +121,5 @@ void GraphSendUVKernel(const Context& ctx,
} // namespace phi
PD_REGISTER_KERNEL(graph_send_uv,
CPU,
ALL_LAYOUT,
phi::GraphSendUVKernel,
float,
double,
int,
int64_t) {}
PD_REGISTER_KERNEL(
send_uv, CPU, ALL_LAYOUT, phi::SendUVKernel, float, double, int, int64_t) {}
......@@ -20,10 +20,10 @@
namespace phi {
template <typename T, typename Context>
void CropTensorKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& shape,
const IntArray& offsets,
DenseTensor* out);
void CropGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const IntArray& offsets,
DenseTensor* x_grad);
} // namespace phi
......@@ -20,10 +20,10 @@
namespace phi {
template <typename T, typename Context>
void CropTensorGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const IntArray& offsets,
DenseTensor* x_grad);
void CropKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& shape,
const IntArray& offsets,
DenseTensor* out);
} // namespace phi
......@@ -956,7 +956,7 @@ struct TanhTripleGradFunctor : public BaseActivationFunctor<T> {
};
template <typename T>
struct BReluFunctor : public BaseActivationFunctor<T> {
struct HardTanhFunctor : public BaseActivationFunctor<T> {
float t_min;
float t_max;
......@@ -974,7 +974,7 @@ struct BReluFunctor : public BaseActivationFunctor<T> {
};
template <typename T>
struct BReluGradFunctor : public BaseActivationFunctor<T> {
struct HardTanhGradFunctor : public BaseActivationFunctor<T> {
float t_min;
float t_max;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
......@@ -2707,7 +2707,7 @@ struct CudaTanhGradFunctor : public BaseActivationFunctor<T> {
};
template <typename T>
struct CudaBReluFunctor : public BaseActivationFunctor<T> {
struct CudaHardTanhFunctor : public BaseActivationFunctor<T> {
float t_min;
float t_max;
......@@ -2775,7 +2775,7 @@ struct CudaMishGradFunctor : public BaseActivationFunctor<T> {
};
template <typename T>
struct CudaBReluGradFunctor : public BaseActivationFunctor<T> {
struct CudaHardTanhGradFunctor : public BaseActivationFunctor<T> {
T zero = static_cast<T>(0.0f);
float t_min;
float t_max;
......
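For reference, the math is unchanged by the rename: hardtanh(x) = min(max(x, t_min), t_max), with gradient dx = dout for t_min < x < t_max and 0 elsewhere. "brelu" (commonly read as bounded relu) was a nonstandard name for the same function, which is why the functors above move to the HardTanh spelling.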
......@@ -21,12 +21,12 @@
namespace phi {
template <typename T, typename Context>
void GaussianRandomKernel(const Context& ctx,
const IntArray& shape,
float mean,
float std,
int seed,
DataType dtype,
DenseTensor* out);
void GaussianKernel(const Context& ctx,
const IntArray& shape,
float mean,
float std,
int seed,
DataType dtype,
DenseTensor* out);
} // namespace phi
......@@ -228,8 +228,8 @@ DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPOUT(Relu6,
CudaRelu6GradFunctor,
threshold);
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(BRelu,
CudaBReluGradFunctor,
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(HardTanh,
CudaHardTanhGradFunctor,
t_min,
t_max);
......@@ -346,7 +346,7 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(atanh_grad, AtanhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(tanh_grad, TanhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(tanh_double_grad, TanhDoubleGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(tanh_triple_grad, TanhTripleGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(brelu_grad, BReluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_tanh_grad, HardTanhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_grad, LeakyReluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_double_grad,
LeakyReluDoubleGradKernel)
......
......@@ -122,7 +122,10 @@ DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(Swish, CudaSwishFunctor, beta)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(Mish, CudaMishFunctor, threshold)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(Celu, CudaCELUFunctor, alpha)
DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS(BRelu, CudaBReluFunctor, t_min, t_max)
DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS(HardTanh,
CudaHardTanhFunctor,
t_min,
t_max)
DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS(Stanh, CudaSTanhFunctor, scale_a, scale_b)
DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS(Softplus,
CudaSoftplusFunctor,
......@@ -193,7 +196,7 @@ PD_REGISTER_ACTIVATION_KERNEL(asinh, AsinhKernel)
PD_REGISTER_ACTIVATION_KERNEL(acosh, AcoshKernel)
PD_REGISTER_ACTIVATION_KERNEL(atanh, AtanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(tanh, TanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(brelu, BReluKernel)
PD_REGISTER_ACTIVATION_KERNEL(hard_tanh, HardTanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(thresholded_relu, ThresholdedReluKernel)
PD_REGISTER_ACTIVATION_KERNEL(relu6, Relu6Kernel)
PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
......
......@@ -12,16 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/crop_tensor_kernel.h"
#include "paddle/phi/kernels/crop_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/crop_tensor_kernel_impl.h"
#include "paddle/phi/kernels/impl/crop_grad_kernel_impl.h"
PD_REGISTER_KERNEL(crop_tensor,
PD_REGISTER_KERNEL(crop_grad,
GPU,
ALL_LAYOUT,
phi::CropTensorKernel,
phi::CropGradKernel,
float,
double,
int,
......
......@@ -12,17 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/crop_tensor_grad_kernel.h"
#include "paddle/phi/kernels/crop_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/crop_tensor_grad_kernel_impl.h"
#include "paddle/phi/kernels/impl/crop_kernel_impl.h"
PD_REGISTER_KERNEL(crop_tensor_grad,
GPU,
ALL_LAYOUT,
phi::CropTensorGradKernel,
float,
double,
int,
int64_t) {}
PD_REGISTER_KERNEL(
crop, GPU, ALL_LAYOUT, phi::CropKernel, float, double, int, int64_t) {}
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/gaussian_random_kernel.h"
#include "paddle/phi/kernels/gaussian_kernel.h"
#include <thrust/random.h>
......@@ -52,13 +52,13 @@ struct GaussianGenerator {
};
template <typename T, typename Context>
void GaussianRandomKernel(const Context& dev_ctx,
const IntArray& shape,
float mean,
float std,
int seed,
DataType dtype,
DenseTensor* out) {
void GaussianKernel(const Context& dev_ctx,
const IntArray& shape,
float mean,
float std,
int seed,
DataType dtype,
DenseTensor* out) {
out->Resize(phi::make_ddim(shape.GetData()));
dev_ctx.template Alloc<T>(out);
if (seed == 0) {
......@@ -78,10 +78,10 @@ void GaussianRandomKernel(const Context& dev_ctx,
} // namespace phi
PD_REGISTER_KERNEL(gaussian_random,
PD_REGISTER_KERNEL(gaussian,
GPU,
ALL_LAYOUT,
phi::GaussianRandomKernel,
phi::GaussianKernel,
phi::dtype::float16,
phi::dtype::bfloat16,
float,
......
......@@ -22,7 +22,7 @@
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/kernels/graph_send_recv_kernel.h"
#include "paddle/phi/kernels/send_u_recv_kernel.h"
namespace phi {
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/graph_send_recv_grad_kernel.h"
#include "paddle/phi/kernels/send_u_recv_grad_kernel.h"
#include <algorithm>
#include <vector>
......@@ -98,15 +98,15 @@ void GraphSendRecvGradOpCUDAKernelLaunchHelper(
}
template <typename T, typename Context>
void GraphSendRecvGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const paddle::optional<DenseTensor>& out,
const paddle::optional<DenseTensor>& dst_count,
const DenseTensor& out_grad,
const std::string& reduce_op,
DenseTensor* x_grad) {
void SendURecvGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const paddle::optional<DenseTensor>& out,
const paddle::optional<DenseTensor>& dst_count,
const DenseTensor& out_grad,
const std::string& reduce_op,
DenseTensor* x_grad) {
auto index_type = src_index.dtype();
if (index_type == phi::DataType::INT32) {
GraphSendRecvGradOpCUDAKernelLaunchHelper<Context, T, int32_t>(
......@@ -135,10 +135,10 @@ void GraphSendRecvGradKernel(const Context& ctx,
} // namespace phi
PD_REGISTER_KERNEL(graph_send_recv_grad,
PD_REGISTER_KERNEL(send_u_recv_grad,
GPU,
ALL_LAYOUT,
phi::GraphSendRecvGradKernel,
phi::SendURecvGradKernel,
float,
double,
int,
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/graph_send_recv_kernel.h"
#include "paddle/phi/kernels/send_u_recv_kernel.h"
#include <thrust/device_vector.h>
#include <thrust/fill.h>
......@@ -154,14 +154,14 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx,
}
template <typename T, typename Context>
void GraphSendRecvKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& reduce_op,
const IntArray& out_size,
DenseTensor* out,
DenseTensor* dst_count) {
void SendURecvKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& reduce_op,
const IntArray& out_size,
DenseTensor* out,
DenseTensor* dst_count) {
auto index_type = src_index.dtype();
auto& out_size_data = out_size.GetData();
if (index_type == phi::DataType::INT32) {
......@@ -187,10 +187,10 @@ void GraphSendRecvKernel(const Context& ctx,
} // namespace phi
PD_REGISTER_KERNEL(graph_send_recv,
PD_REGISTER_KERNEL(send_u_recv,
GPU,
ALL_LAYOUT,
phi::GraphSendRecvKernel,
phi::SendURecvKernel,
float,
double,
int,
......
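Editor's note: SendURecvKernel gathers rows of x at src_index and scatter-reduces them into the rows named by dst_index. A usage sketch via the public wrapper (assuming paddle.geometric.send_u_recv from Paddle 2.4+):

    import paddle

    x = paddle.to_tensor([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float32")
    src_index = paddle.to_tensor([0, 1, 2, 0], dtype="int32")
    dst_index = paddle.to_tensor([1, 2, 1, 0], dtype="int32")
    # edge e carries x[src_index[e]] and is summed into row dst_index[e]
    out = paddle.geometric.send_u_recv(x, src_index, dst_index, reduce_op="sum")
    # out: [[0., 2., 3.], [2., 8., 10.], [1., 4., 5.]]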
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/graph_send_ue_recv_grad_kernel.h"
#include "paddle/phi/kernels/send_ue_recv_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/core/kernel_registry.h"
......@@ -556,18 +556,18 @@ void GraphSendUERecvGradOpCUDAKernelLaunchHelper(
}
template <typename T, typename Context>
void GraphSendUERecvGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const paddle::optional<DenseTensor>& out,
const paddle::optional<DenseTensor>& dst_count,
const DenseTensor& out_grad,
const std::string& message_op,
const std::string& reduce_op,
DenseTensor* x_grad,
DenseTensor* y_grad) {
void SendUERecvGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const paddle::optional<DenseTensor>& out,
const paddle::optional<DenseTensor>& dst_count,
const DenseTensor& out_grad,
const std::string& message_op,
const std::string& reduce_op,
DenseTensor* x_grad,
DenseTensor* y_grad) {
auto index_type = src_index.dtype();
if (index_type == phi::DataType::INT32) {
GraphSendUERecvGradOpCUDAKernelLaunchHelper<Context, T, int32_t>(
......@@ -602,10 +602,10 @@ void GraphSendUERecvGradKernel(const Context& ctx,
} // namespace phi
PD_REGISTER_KERNEL(graph_send_ue_recv_grad,
PD_REGISTER_KERNEL(send_ue_recv_grad,
GPU,
ALL_LAYOUT,
phi::GraphSendUERecvGradKernel,
phi::SendUERecvGradKernel,
float,
double,
int,
......
......@@ -12,10 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/graph_send_ue_recv_kernel.h"
#include "paddle/phi/kernels/gpu/graph_send_recv_funcs.h"
#include "paddle/phi/kernels/gpu/graph_send_ue_recv_funcs.h"
#include "paddle/phi/kernels/impl/graph_message_passing_impl.h"
#include "paddle/phi/kernels/send_ue_recv_kernel.h"
#include <thrust/device_vector.h>
#include <thrust/fill.h>
......@@ -26,6 +23,9 @@
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/gpu/graph_send_recv_funcs.h"
#include "paddle/phi/kernels/gpu/graph_send_ue_recv_funcs.h"
#include "paddle/phi/kernels/impl/graph_message_passing_impl.h"
namespace phi {
......@@ -282,16 +282,16 @@ void GraphSendUERecvOpCUDAKernelLaunchHelper(const Context& ctx,
}
template <typename T, typename Context>
void GraphSendUERecvKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& message_op,
const std::string& reduce_op,
const IntArray& out_size,
DenseTensor* out,
DenseTensor* dst_count) {
void SendUERecvKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& message_op,
const std::string& reduce_op,
const IntArray& out_size,
DenseTensor* out,
DenseTensor* dst_count) {
auto index_type = src_index.dtype();
auto& out_size_data = out_size.GetData();
if (index_type == phi::DataType::INT32) {
......@@ -323,10 +323,10 @@ void GraphSendUERecvKernel(const Context& ctx,
} // namespace phi
PD_REGISTER_KERNEL(graph_send_ue_recv,
PD_REGISTER_KERNEL(send_ue_recv,
GPU,
ALL_LAYOUT,
phi::GraphSendUERecvKernel,
phi::SendUERecvKernel,
float,
double,
int,
......
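Editor's note: SendUERecvKernel differs from SendURecvKernel in that each edge first combines the gathered x[src] row with a per-edge tensor y via message_op before the reduce_op scatter. A sketch assuming paddle.geometric.send_ue_recv; the per-edge y below is illustrative:

    import paddle

    x = paddle.to_tensor([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float32")
    y = paddle.ones([4, 1], dtype="float32")  # one value per edge
    src_index = paddle.to_tensor([0, 1, 2, 0], dtype="int32")
    dst_index = paddle.to_tensor([1, 2, 1, 0], dtype="int32")
    # message = x[src] + y (broadcast), then summed per destination row
    out = paddle.geometric.send_ue_recv(x, y, src_index, dst_index,
                                        message_op="add", reduce_op="sum")
    # out: [[1., 3., 4.], [4., 10., 12.], [2., 5., 6.]]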
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/graph_send_uv_grad_kernel.h"
#include "paddle/phi/kernels/send_uv_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/hostdevice.h"
......@@ -285,15 +285,15 @@ void GraphSendUVGradOpCUDAKernelLaunchHelper(const Context& ctx,
}
template <typename T, typename Context>
void GraphSendUVGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const DenseTensor& out_grad,
const std::string& message_op,
DenseTensor* x_grad,
DenseTensor* y_grad) {
void SendUVGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const DenseTensor& out_grad,
const std::string& message_op,
DenseTensor* x_grad,
DenseTensor* y_grad) {
auto index_type = src_index.dtype();
if (index_type == phi::DataType::INT32) {
GraphSendUVGradOpCUDAKernelLaunchHelper<Context, T, int32_t>(
......@@ -306,10 +306,10 @@ void GraphSendUVGradKernel(const Context& ctx,
} // namespace phi
PD_REGISTER_KERNEL(graph_send_uv_grad,
PD_REGISTER_KERNEL(send_uv_grad,
GPU,
ALL_LAYOUT,
phi::GraphSendUVGradKernel,
phi::SendUVGradKernel,
float,
double,
int,
......
......@@ -12,9 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/graph_send_uv_kernel.h"
#include "paddle/phi/kernels/gpu/graph_send_ue_recv_funcs.h"
#include "paddle/phi/kernels/impl/graph_message_passing_impl.h"
#include "paddle/phi/kernels/send_uv_kernel.h"
#include <thrust/device_vector.h>
......@@ -22,6 +20,8 @@
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/gpu/graph_send_ue_recv_funcs.h"
#include "paddle/phi/kernels/impl/graph_message_passing_impl.h"
namespace phi {
......@@ -142,13 +142,13 @@ void GraphSendUVOpCUDAKernelLaunchHelper(const Context& ctx,
}
template <typename T, typename Context>
void GraphSendUVKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& message_op,
DenseTensor* out) {
void SendUVKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& message_op,
DenseTensor* out) {
auto index_type = src_index.dtype();
if (index_type == phi::DataType::INT32) {
GraphSendUVOpCUDAKernelLaunchHelper<Context, T, int32_t>(
......@@ -161,10 +161,10 @@ void GraphSendUVKernel(const Context& ctx,
} // namespace phi
PD_REGISTER_KERNEL(graph_send_uv,
PD_REGISTER_KERNEL(send_uv,
GPU,
ALL_LAYOUT,
phi::GraphSendUVKernel,
phi::SendUVKernel,
float,
double,
int,
......
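Editor's note: SendUVKernel is the edge-wise counterpart — it combines x[src] with y[dst] per edge and returns one row per edge, with no reduction. A sketch assuming paddle.geometric.send_uv:

    import paddle

    x = paddle.to_tensor([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float32")
    y = paddle.to_tensor([[0, 1, 2], [2, 3, 4], [4, 5, 6]], dtype="float32")
    src_index = paddle.to_tensor([0, 1, 2, 0], dtype="int32")
    dst_index = paddle.to_tensor([1, 2, 1, 0], dtype="int32")
    out = paddle.geometric.send_uv(x, y, src_index, dst_index, message_op="add")
    # one row per edge: x[src[e]] + y[dst[e]]
    # out: [[2., 5., 7.], [5., 9., 11.], [4., 9., 11.], [0., 3., 5.]]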
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/utils/optional.h"
namespace phi {
template <typename T, typename Context>
void GraphSendUERecvGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const paddle::optional<DenseTensor>& out,
const paddle::optional<DenseTensor>& dst_count,
const DenseTensor& out_grad,
const std::string& message_op,
const std::string& reduce_op,
DenseTensor* x_grad,
DenseTensor* y_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void GraphSendUERecvKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& message_op,
const std::string& reduce_op,
const IntArray& out_size,
DenseTensor* out,
DenseTensor* dst_count);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void GraphSendUVGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const DenseTensor& out_grad,
const std::string& message_op,
DenseTensor* x_grad,
DenseTensor* y_grad);
} // namespace phi
......@@ -37,8 +37,8 @@ void AddmmKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& x,
const DenseTensor& y,
float alpha,
float beta,
float alpha,
DenseTensor* out) {
auto input_dims = input.dims();
auto x_dims = x.dims();
......
......@@ -15,7 +15,7 @@
#pragma once
#include "paddle/phi/kernels/crop_tensor_grad_kernel.h"
#include "paddle/phi/kernels/crop_grad_kernel.h"
#include <vector>
......@@ -52,11 +52,11 @@ void CropTensorGradFunction(const Context& dev_ctx,
}
template <typename T, typename Context>
void CropTensorGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
const DenseTensor& x,
const IntArray& offsets,
DenseTensor* x_grad) {
void CropGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
const DenseTensor& x,
const IntArray& offsets,
DenseTensor* x_grad) {
size_t rank = out_grad.dims().size();
PADDLE_ENFORCE_GE(
rank,
......
......@@ -14,7 +14,7 @@
#pragma once
#include "paddle/phi/kernels/crop_tensor_kernel.h"
#include "paddle/phi/kernels/crop_kernel.h"
#include <utility>
#include <vector>
......@@ -127,11 +127,11 @@ void CropTensorFunction(const Context& dev_ctx,
}
template <typename T, typename Context>
void CropTensorKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& shape,
const IntArray& offsets,
DenseTensor* out) {
void CropKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& shape,
const IntArray& offsets,
DenseTensor* out) {
int rank = x.dims().size();
PADDLE_ENFORCE_GE(
rank,
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/gaussian_random_kernel.h"
#include "paddle/phi/kernels/gaussian_kernel.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/kernel_registry.h"
......@@ -20,13 +20,13 @@
namespace phi {
template <typename T, typename Context>
void GaussianRandomKernel(const Context& ctx,
const IntArray& shape,
float mean,
float std,
int seed,
DataType dtype,
DenseTensor* out) {
void GaussianKernel(const Context& ctx,
const IntArray& shape,
float mean,
float std,
int seed,
DataType dtype,
DenseTensor* out) {
std::normal_distribution<T> dist(mean, std);
std::shared_ptr<std::mt19937_64> engine;
if (seed) {
......@@ -51,5 +51,4 @@ void GaussianRandomKernel(const Context& ctx,
} // namespace phi
PD_REGISTER_KERNEL(
gaussian_random, OneDNN, ONEDNN, phi::GaussianRandomKernel, float) {}
PD_REGISTER_KERNEL(gaussian, OneDNN, ONEDNN, phi::GaussianKernel, float) {}
......@@ -16,19 +16,19 @@
#include <string>
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/utils/optional.h"
namespace phi {
template <typename T, typename Context>
void GraphSendRecvKernel(const Context& ctx,
void SendURecvGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const paddle::optional<DenseTensor>& out,
const paddle::optional<DenseTensor>& dst_count,
const DenseTensor& out_grad,
const std::string& reduce_op,
const IntArray& out_size,
DenseTensor* out,
DenseTensor* dst_count);
DenseTensor* x_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void SendURecvKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& reduce_op,
const IntArray& out_size,
DenseTensor* out,
DenseTensor* dst_count);
} // namespace phi
......@@ -15,20 +15,22 @@
#pragma once
#include <string>
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/utils/optional.h"
namespace phi {
template <typename T, typename Context>
void GraphSendRecvGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const paddle::optional<DenseTensor>& out,
const paddle::optional<DenseTensor>& dst_count,
const DenseTensor& out_grad,
const std::string& reduce_op,
DenseTensor* x_grad);
void SendUERecvGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const paddle::optional<DenseTensor>& out,
const paddle::optional<DenseTensor>& dst_count,
const DenseTensor& out_grad,
const std::string& message_op,
const std::string& reduce_op,
DenseTensor* x_grad,
DenseTensor* y_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void SendUERecvKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& message_op,
const std::string& reduce_op,
const IntArray& out_size,
DenseTensor* out,
DenseTensor* dst_count);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void SendUVGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const DenseTensor& out_grad,
const std::string& message_op,
DenseTensor* x_grad,
DenseTensor* y_grad);
} // namespace phi
......@@ -20,12 +20,12 @@
namespace phi {
template <typename T, typename Context>
void GraphSendUVKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& message_op,
DenseTensor* out);
void SendUVKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& message_op,
DenseTensor* out);
} // namespace phi
......@@ -27,8 +27,8 @@ void AddmmCooCooKernel(const Context& dev_ctx,
const SparseCooTensor& input,
const SparseCooTensor& x,
const SparseCooTensor& y,
float alpha,
float beta,
float alpha,
SparseCooTensor* out);
/* DENSE + COO @ DENSE -> DENSE */
......@@ -37,8 +37,8 @@ void AddmmCooDenseKernel(const Context& dev_ctx,
const DenseTensor& input,
const SparseCooTensor& x,
const DenseTensor& y,
float alpha,
float beta,
float alpha,
DenseTensor* out);
// TODO(zhouwei25): implement " CSR + CSR @ CSR -> CSR"
......@@ -47,8 +47,8 @@ void AddmmCsrCsrKernel(const Context& dev_ctx,
const SparseCsrTensor& input,
const SparseCsrTensor& x,
const SparseCsrTensor& y,
float alpha,
float beta,
float alpha,
SparseCsrTensor* out);
/* DENSE + CSR @ DENSE -> DENSE */
......@@ -57,8 +57,8 @@ void AddmmCsrDenseKernel(const Context& dev_ctx,
const DenseTensor& input,
const SparseCsrTensor& x,
const DenseTensor& y,
float alpha,
float beta,
float alpha,
DenseTensor* out);
} // namespace sparse
......
......@@ -25,8 +25,8 @@ void AddmmCooDenseKernel(const Context& dev_ctx,
const DenseTensor& input,
const SparseCooTensor& x,
const DenseTensor& y,
float alpha,
float beta,
float alpha,
DenseTensor* out) {
PADDLE_THROW(phi::errors::Unimplemented(
"Not support CPU kernel of 'sparse.addmm' now."));
......@@ -38,8 +38,8 @@ void AddmmCsrDenseKernel(const Context& dev_ctx,
const DenseTensor& input,
const SparseCsrTensor& x,
const DenseTensor& y,
float alpha,
float beta,
float alpha,
DenseTensor* out) {
PADDLE_THROW(phi::errors::Unimplemented(
"Not support CPU kernel of 'sparse.addmm' now."));
......
......@@ -31,8 +31,8 @@ void AddmmKernelImpl(const Context& dev_ctx,
const DenseTensor& input,
const TensorType& x,
const DenseTensor& y,
float alpha,
float beta,
float alpha,
DenseTensor* out) {
#if CUDA_VERSION >= 11000
std::vector<int64_t> input_dim = phi::vectorize(input.dims());
......@@ -107,10 +107,10 @@ void AddmmCooDenseKernel(const Context& dev_ctx,
const DenseTensor& input,
const SparseCooTensor& x,
const DenseTensor& y,
float alpha,
float beta,
float alpha,
DenseTensor* out) {
AddmmKernelImpl<T>(dev_ctx, input, x, y, alpha, beta, out);
AddmmKernelImpl<T>(dev_ctx, input, x, y, beta, alpha, out);
}
template <typename T, typename Context>
......@@ -118,10 +118,10 @@ void AddmmCsrDenseKernel(const Context& dev_ctx,
const DenseTensor& input,
const SparseCsrTensor& x,
const DenseTensor& y,
float alpha,
float beta,
float alpha,
DenseTensor* out) {
AddmmKernelImpl<T>(dev_ctx, input, x, y, alpha, beta, out);
AddmmKernelImpl<T>(dev_ctx, input, x, y, beta, alpha, out);
}
} // namespace sparse
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/gaussian_random_kernel.h"
#include "paddle/phi/kernels/gaussian_kernel.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/memory/memcpy.h"
......@@ -22,13 +22,13 @@
namespace phi {
template <typename T, typename Context>
void GaussianRandomKernel(const Context& ctx,
const IntArray& shape,
float mean,
float std,
int seed,
DataType dtype,
DenseTensor* out) {
void GaussianKernel(const Context& ctx,
const IntArray& shape,
float mean,
float std,
int seed,
DataType dtype,
DenseTensor* out) {
std::normal_distribution<T> dist(mean, std);
int64_t size = out->numel();
ctx.template Alloc<T>(out);
......@@ -51,5 +51,4 @@ void GaussianRandomKernel(const Context& ctx,
} // namespace phi
PD_REGISTER_KERNEL(
gaussian_random, XPU, ALL_LAYOUT, phi::GaussianRandomKernel, float) {}
PD_REGISTER_KERNEL(gaussian, XPU, ALL_LAYOUT, phi::GaussianKernel, float) {}
......@@ -41,7 +41,7 @@ namespace phi {
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Square, "square", ); // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(BRelu, "brelu", "t_min" comma "t_max");
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(HardTanh, "hard_tanh", "t_min" comma "t_max");
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(LeakyRelu, "leaky_relu", "alpha");
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(ThresholdedRelu,
"thresholded_relu",
......@@ -228,6 +228,8 @@ PD_REGISTER_BASE_KERNEL_NAME(sqrt_grad_grad, sqrt_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(rsqrt_grad_grad, rsqrt_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(celu_grad_grad, celu_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(square_grad_grad, square_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(brelu, hard_tanh);
PD_REGISTER_BASE_KERNEL_NAME(brelu_grad, hard_tanh_grad);
PD_REGISTER_ARG_MAPPING_FN(relu_grad, phi::ReluGradOpArgumentMapping);
......@@ -252,7 +254,7 @@ PD_REGISTER_ARG_MAPPING_FN(tanh_grad_grad,
phi::TanhDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(tanh_triple_grad,
phi::TanhTripleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(brelu_grad, phi::BReluGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(brelu_grad, phi::HardTanhGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(leaky_relu, phi::LeakyReluOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(leaky_relu_grad,
phi::LeakyReluGradOpArgumentMapping);
......
......@@ -16,6 +16,11 @@
namespace phi {
KernelSignature AddmmOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"addmm", {"Input", "X", "Y"}, {"Beta", "Alpha"}, {"Out"});
}
KernelSignature AddmmGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("addmm_grad",
{"Input", "X", "Y", "Out@GRAD"},
......@@ -25,4 +30,5 @@ KernelSignature AddmmGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(addmm, phi::AddmmOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(addmm_grad, phi::AddmmGradOpArgumentMapping);
......@@ -20,35 +20,31 @@ KernelSignature CropTensorOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.InputSize("ShapeTensor") > 0) {
if (ctx.InputSize("OffsetsTensor") > 0) {
return KernelSignature(
"crop_tensor", {"X"}, {"ShapeTensor", "OffsetsTensor"}, {"Out"});
"crop", {"X"}, {"ShapeTensor", "OffsetsTensor"}, {"Out"});
} else if (ctx.HasInput("Offsets")) {
return KernelSignature(
"crop_tensor", {"X"}, {"ShapeTensor", "Offsets"}, {"Out"});
"crop", {"X"}, {"ShapeTensor", "Offsets"}, {"Out"});
} else {
return KernelSignature(
"crop_tensor", {"X"}, {"ShapeTensor", "offsets"}, {"Out"});
"crop", {"X"}, {"ShapeTensor", "offsets"}, {"Out"});
}
} else if (ctx.HasInput("Shape")) {
if (ctx.InputSize("OffsetsTensor") > 0) {
return KernelSignature(
"crop_tensor", {"X"}, {"Shape", "OffsetsTensor"}, {"Out"});
"crop", {"X"}, {"Shape", "OffsetsTensor"}, {"Out"});
} else if (ctx.HasInput("Offsets")) {
return KernelSignature(
"crop_tensor", {"X"}, {"Shape", "Offsets"}, {"Out"});
return KernelSignature("crop", {"X"}, {"Shape", "Offsets"}, {"Out"});
} else {
return KernelSignature(
"crop_tensor", {"X"}, {"Shape", "offsets"}, {"Out"});
return KernelSignature("crop", {"X"}, {"Shape", "offsets"}, {"Out"});
}
} else {
if (ctx.InputSize("OffsetsTensor") > 0) {
return KernelSignature(
"crop_tensor", {"X"}, {"shape", "OffsetsTensor"}, {"Out"});
"crop", {"X"}, {"shape", "OffsetsTensor"}, {"Out"});
} else if (ctx.HasInput("Offsets")) {
return KernelSignature(
"crop_tensor", {"X"}, {"shape", "Offsets"}, {"Out"});
return KernelSignature("crop", {"X"}, {"shape", "Offsets"}, {"Out"});
} else {
return KernelSignature(
"crop_tensor", {"X"}, {"shape", "offsets"}, {"Out"});
return KernelSignature("crop", {"X"}, {"shape", "offsets"}, {"Out"});
}
}
}
......@@ -57,18 +53,21 @@ KernelSignature CropTensorGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
if (ctx.InputSize("OffsetsTensor") > 0) {
return KernelSignature(
"crop_tensor_grad", {"X", "Out@GRAD"}, {"OffsetsTensor"}, {"X@GRAD"});
"crop_grad", {"X", "Out@GRAD"}, {"OffsetsTensor"}, {"X@GRAD"});
} else if (ctx.HasInput("Offsets")) {
return KernelSignature(
"crop_tensor_grad", {"X", "Out@GRAD"}, {"Offsets"}, {"X@GRAD"});
"crop_grad", {"X", "Out@GRAD"}, {"Offsets"}, {"X@GRAD"});
} else {
return KernelSignature(
"crop_tensor_grad", {"X", "Out@GRAD"}, {"offsets"}, {"X@GRAD"});
"crop_grad", {"X", "Out@GRAD"}, {"offsets"}, {"X@GRAD"});
}
}
} // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(crop_tensor, crop);
PD_REGISTER_BASE_KERNEL_NAME(crop_tensor_grad, crop_grad);
PD_REGISTER_ARG_MAPPING_FN(crop_tensor, phi::CropTensorOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(crop_tensor_grad,
phi::CropTensorGradOpArgumentMapping);
......@@ -22,13 +22,11 @@ KernelSignature GaussianRandomOpArgumentMapping(
if (ctx.InputSize("ShapeTensorList") > 0) {
// Infer output shape by Attr("shape") in CompileTime if it is specified.
if (!ctx.IsRuntime() && !shape.empty()) {
return KernelSignature("gaussian_random",
{},
{"shape", "mean", "std", "seed", "dtype"},
{"Out"});
return KernelSignature(
"gaussian", {}, {"shape", "mean", "std", "seed", "dtype"}, {"Out"});
} else {
return KernelSignature(
"gaussian_random",
"gaussian",
{},
{"ShapeTensorList", "mean", "std", "seed", "dtype"},
{"Out"});
......@@ -36,19 +34,19 @@ KernelSignature GaussianRandomOpArgumentMapping(
}
if (ctx.HasInput("ShapeTensor") && shape.empty()) {
return KernelSignature("gaussian_random",
return KernelSignature("gaussian",
{},
{"ShapeTensor", "mean", "std", "seed", "dtype"},
{"Out"});
}
return KernelSignature("gaussian_random",
{},
{"shape", "mean", "std", "seed", "dtype"},
{"Out"});
return KernelSignature(
"gaussian", {}, {"shape", "mean", "std", "seed", "dtype"}, {"Out"});
}
} // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(gaussian_random, gaussian);
PD_REGISTER_ARG_MAPPING_FN(gaussian_random,
phi::GaussianRandomOpArgumentMapping);
......@@ -19,12 +19,12 @@ namespace phi {
KernelSignature GraphSendRecvOpArgumentMapping(
const ArgumentMappingContext& ctx) {
if (ctx.HasInput("Out_size")) {
return KernelSignature("graph_send_recv",
return KernelSignature("send_u_recv",
{"X", "Src_index", "Dst_index"},
{"reduce_op", "Out_size"},
{"Out", "Dst_count"});
} else {
return KernelSignature("graph_send_recv",
return KernelSignature("send_u_recv",
{"X", "Src_index", "Dst_index"},
{"reduce_op", "out_size"},
{"Out", "Dst_count"});
......@@ -34,7 +34,7 @@ KernelSignature GraphSendRecvOpArgumentMapping(
KernelSignature GraphSendRecvGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"graph_send_recv_grad",
"send_u_recv_grad",
{"X", "Src_index", "Dst_index", "Out", "Dst_count", "Out@GRAD"},
{"reduce_op"},
{"X@GRAD"});
......@@ -42,6 +42,9 @@ KernelSignature GraphSendRecvGradOpArgumentMapping(
} // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(graph_send_recv, send_u_recv);
PD_REGISTER_BASE_KERNEL_NAME(graph_send_recv_grad, send_u_recv_grad);
PD_REGISTER_ARG_MAPPING_FN(graph_send_recv,
phi::GraphSendRecvOpArgumentMapping);
......
......@@ -19,12 +19,12 @@ namespace phi {
KernelSignature GraphSendUERecvOpArgumentMapping(
const ArgumentMappingContext& ctx) {
if (ctx.HasInput("Out_size")) {
return KernelSignature("graph_send_ue_recv",
return KernelSignature("send_ue_recv",
{"X", "Y", "Src_index", "Dst_index"},
{"message_op", "reduce_op", "Out_size"},
{"Out", "Dst_count"});
} else {
return KernelSignature("graph_send_ue_recv",
return KernelSignature("send_ue_recv",
{"X", "Y", "Src_index", "Dst_index"},
{"message_op", "reduce_op", "out_size"},
{"Out", "Dst_count"});
......@@ -34,7 +34,7 @@ KernelSignature GraphSendUERecvOpArgumentMapping(
KernelSignature GraphSendUERecvGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"graph_send_ue_recv_grad",
"send_ue_recv_grad",
{"X", "Y", "Src_index", "Dst_index", "Out", "Dst_count", "Out@GRAD"},
{"message_op", "reduce_op"},
{"X@GRAD", "Y@GRAD"});
......@@ -42,6 +42,9 @@ KernelSignature GraphSendUERecvGradOpArgumentMapping(
} // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(graph_send_ue_recv, send_ue_recv);
PD_REGISTER_BASE_KERNEL_NAME(graph_send_ue_recv_grad, send_ue_recv_grad);
PD_REGISTER_ARG_MAPPING_FN(graph_send_ue_recv,
phi::GraphSendUERecvOpArgumentMapping);
......
......@@ -435,7 +435,7 @@ class NormalInitializer(Initializer):
if in_dygraph_mode():
place = _current_expected_place()
out_var = _C_ops.gaussian_random(
out_var = _C_ops.gaussian(
var.shape,
self._mean,
self._std_dev,
......@@ -737,7 +737,7 @@ class XavierInitializer(Initializer):
if in_dygraph_mode():
place = _current_expected_place()
out_var = _C_ops.gaussian_random(
out_var = _C_ops.gaussian(
out_var.shape, 0.0, std, self._seed, out_dtype, place
)
else:
......@@ -949,7 +949,7 @@ class MSRAInitializer(Initializer):
std = gain / math.sqrt(float(fan_in))
if in_dygraph_mode():
place = _current_expected_place()
out_var = _C_ops.gaussian_random(
out_var = _C_ops.gaussian(
out_var.shape, 0.0, std, self._seed, out_dtype, place
)
else:
......
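Editor's note: these initializer paths now lower to _C_ops.gaussian in dygraph mode; user code is unaffected. For instance (a hypothetical layer, assuming the standard paddle.nn.initializer.Normal API):

    import paddle

    # Normal/Xavier/MSRA initializers all sample through the renamed gaussian op
    linear = paddle.nn.Linear(
        4, 4, weight_attr=paddle.nn.initializer.Normal(mean=0.0, std=0.02))
    print(linear.weight.shape)                # [4, 4]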
......@@ -11837,7 +11837,7 @@ def gaussian_random(
if in_dygraph_mode():
shape = utils.convert_shape_to_list(shape)
place = _current_expected_place()
return _C_ops.gaussian_random(
return _C_ops.gaussian(
shape, float(mean), float(std), seed, dtype, place
)
......
......@@ -133,7 +133,7 @@ def send_u_recv(
return out
if in_dygraph_mode():
out_size = convert_out_size_to_list(out_size)
return _C_ops.graph_send_recv(
return _C_ops.send_u_recv(
x, src_index, dst_index, reduce_op.upper(), out_size
)
......@@ -320,7 +320,7 @@ def send_ue_recv(
return out
if in_dygraph_mode():
out_size = convert_out_size_to_list(out_size)
return _C_ops.graph_send_ue_recv(
return _C_ops.send_ue_recv(
x,
y,
src_index,
......@@ -464,16 +464,14 @@ def send_uv(x, y, src_index, dst_index, message_op="add", name=None):
y = 1.0 / (y + 1e-12)
if in_dygraph_mode():
return _C_ops.graph_send_uv(
x, y, src_index, dst_index, message_op.upper()
)
return _C_ops.send_uv(x, y, src_index, dst_index, message_op.upper())
else:
if _in_legacy_dygraph():
return _legacy_C_ops.graph_send_uv(
x, y, src_index, dst_index, "message_op", message_op.upper()
)
else:
helper = LayerHelper("send_uv", **locals())
helper = LayerHelper("graph_send_uv", **locals())
check_variable_and_dtype(
x,
'x',
......
......@@ -139,7 +139,7 @@ def graph_send_recv(
return out
if in_dygraph_mode():
out_size = convert_out_size_to_list(out_size)
return _C_ops.graph_send_recv(
return _C_ops.send_u_recv(
x, src_index, dst_index, pool_type.upper(), out_size
)
......
......@@ -288,7 +288,7 @@ def hardtanh(x, min=-1.0, max=1.0, name=None):
"""
if in_dygraph_mode():
return _C_ops.brelu(x, min, max)
return _C_ops.hardtanh(x, min, max)
if _in_legacy_dygraph():
return _legacy_C_ops.brelu(x, 't_min', min, 't_max', max)
......
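Editor's note: user-facing behavior of hardtanh is unchanged by the brelu → hardtanh rename; the op still clips inputs to [min, max]. A quick check, using the signature hardtanh(x, min=-1.0, max=1.0) from the hunk above:

    import paddle
    import paddle.nn.functional as F

    x = paddle.to_tensor([-1.5, 0.3, 2.5])
    out = F.hardtanh(x)                       # defaults: min=-1.0, max=1.0
    # out: [-1.0, 0.3, 1.0]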
......@@ -106,7 +106,7 @@ class Orthogonal(Initializer):
if framework.in_dygraph_mode():
with no_grad():
place = framework._current_expected_place()
normal_var = _C_ops.gaussian_random(
normal_var = _C_ops.gaussian(
flatten_shape, 0.0, 1.0, self._seed, var.dtype, place
)
q, r = _C_ops.qr(normal_var, 'reduced')
......
......@@ -79,4 +79,4 @@ def addmm(input, x, y, beta=1.0, alpha=1.0, name=None):
out = paddle.sparse.addmm(input, x, y, 3.0, 2.0)
"""
return _C_ops.sparse_addmm(input, x, y, alpha, beta)
return _C_ops.sparse_addmm(input, x, y, beta, alpha)
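Editor's note: after this fix, paddle.sparse.addmm passes beta before alpha to the C++ API, so the result still matches beta * input + alpha * (x @ y). A sketch, GPU only per the Unimplemented CPU kernels earlier in this diff; the to_sparse_coo conversion is an assumption here:

    import paddle

    input = paddle.rand([3, 2])
    x = paddle.rand([3, 3]).to_sparse_coo(2)  # DENSE + COO @ DENSE -> DENSE
    y = paddle.rand([3, 2])
    out = paddle.sparse.addmm(input, x, y, 3.0, 2.0)  # beta=3.0, alpha=2.0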
......@@ -716,7 +716,7 @@ def crop(x, shape=None, offsets=None, name=None):
shape = x.shape
if in_dygraph_mode():
return _C_ops.crop_tensor(x, shape, offsets)
return _C_ops.crop(x, shape, offsets)
out = helper.create_variable_for_type_inference(x.dtype)
ipts = {'X': x}
......
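Editor's note: the crop wrapper keeps its Python signature; only the dispatched op name changes from crop_tensor to crop. Example:

    import paddle

    x = paddle.to_tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype="float32")
    # take a 2x2 window starting at row 1, column 0
    out = paddle.crop(x, shape=[2, 2], offsets=[1, 0])
    # out: [[4., 5.], [7., 8.]]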
......@@ -1941,7 +1941,7 @@ def addmm(input, x, y, beta=1.0, alpha=1.0, name=None):
)
if in_dygraph_mode():
return _C_ops.addmm(input, x, y, alpha, beta)
return _C_ops.addmm(input, x, y, beta, alpha)
else:
if _in_legacy_dygraph():
out = _legacy_C_ops.addmm(input, x, y, "Alpha", alpha, "Beta", beta)
......
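Editor's note: the Python keyword order (beta before alpha) is now mirrored by the C++ call, so _C_ops.addmm(input, x, y, beta, alpha) computes the usual beta * input + alpha * (x @ y). A reference check:

    import paddle

    input = paddle.ones([2, 2])
    x = paddle.ones([2, 3])
    y = paddle.ones([3, 2])
    out = paddle.addmm(input, x, y, beta=0.5, alpha=5.0)
    ref = 0.5 * input + 5.0 * paddle.matmul(x, y)   # every entry: 0.5 + 15.0
    # paddle.allclose(out, ref) -> True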
......@@ -257,7 +257,7 @@ def gaussian(shape, mean=0.0, std=1.0, dtype=None, name=None):
if in_dygraph_mode():
shape = utils.convert_shape_to_list(shape)
place = _current_expected_place()
return _C_ops.gaussian_random(
return _C_ops.gaussian(
shape, float(mean), float(std), seed, dtype, place
)
......