未验证 提交 92cae577 编写于 作者: Y YuanRisheng 提交者: GitHub

supplement header file's code (#50826)

上级 bfa217e4
...@@ -39,12 +39,12 @@ DECALRE_COMPARE_KERNEL(Equal) ...@@ -39,12 +39,12 @@ DECALRE_COMPARE_KERNEL(Equal)
DECALRE_COMPARE_KERNEL(NotEqual) DECALRE_COMPARE_KERNEL(NotEqual)
#undef DECALRE_COMPARE_KERNEL #undef DECALRE_COMPARE_KERNEL
#define DECALRE_COMPARE_ALL_KERNEL(compare_all_kernel) \ #define DECALRE_COMPARE_ALL_KERNEL(compare_all) \
template <typename T, typename Context> \ template <typename T, typename Context> \
void compare_all_kernel(const Context& ctx, \ void compare_all##Kernel(const Context& ctx, \
const DenseTensor& x, \ const DenseTensor& x, \
const DenseTensor& y, \ const DenseTensor& y, \
DenseTensor* out); DenseTensor* out);
DECALRE_COMPARE_ALL_KERNEL(EqualAll) DECALRE_COMPARE_ALL_KERNEL(EqualAll)
#undef DECALRE_COMPARE_KERNEL #undef DECALRE_COMPARE_KERNEL
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
namespace phi { namespace phi {
template <typename T, typename Context> template <typename T, typename Context>
void EigGardKernel(const Context& dev_ctx, void EigGradKernel(const Context& dev_ctx,
const DenseTensor& out_w, const DenseTensor& out_w,
const DenseTensor& out_v, const DenseTensor& out_v,
const DenseTensor& dout_w, const DenseTensor& dout_w,
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
namespace phi { namespace phi {
template <typename T, typename Context> template <typename T, typename Context>
void EighGardKernel(const Context& dev_ctx, void EighGradKernel(const Context& dev_ctx,
const DenseTensor& out_w, const DenseTensor& out_w,
const DenseTensor& out_v, const DenseTensor& out_v,
const DenseTensor& dout_w, const DenseTensor& dout_w,
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
namespace phi { namespace phi {
template <typename T, typename Context> template <typename T, typename Context>
void GatherGradNdKernel(const Context &ctx, void GatherNdGradKernel(const Context &ctx,
const DenseTensor &x, const DenseTensor &x,
const DenseTensor &index, const DenseTensor &index,
const DenseTensor &out_grad, const DenseTensor &out_grad,
......
...@@ -36,4 +36,76 @@ void BilinearInterpGradKernel( ...@@ -36,4 +36,76 @@ void BilinearInterpGradKernel(
int align_mode, int align_mode,
DenseTensor* x_grad); DenseTensor* x_grad);
template <typename T, typename Context>
void LinearInterpGradKernel(
const Context& dev_ctx,
const DenseTensor& x,
const paddle::optional<DenseTensor>& out_size,
const paddle::optional<std::vector<const DenseTensor*>>& size_tensor,
const paddle::optional<DenseTensor>& scale_tensor,
const DenseTensor& out_grad,
const std::string& data_layout,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* x_grad);
template <typename T, typename Context>
void TrilinearInterpGradKernel(
const Context& dev_ctx,
const DenseTensor& x,
const paddle::optional<DenseTensor>& out_size,
const paddle::optional<std::vector<const DenseTensor*>>& size_tensor,
const paddle::optional<DenseTensor>& scale_tensor,
const DenseTensor& out_grad,
const std::string& data_layout,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* x_grad);
template <typename T, typename Context>
void NearestInterpGradKernel(
const Context& dev_ctx,
const DenseTensor& x,
const paddle::optional<DenseTensor>& out_size,
const paddle::optional<std::vector<const DenseTensor*>>& size_tensor,
const paddle::optional<DenseTensor>& scale_tensor,
const DenseTensor& out_grad,
const std::string& data_layout,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* x_grad);
template <typename T, typename Context>
void BicubicInterpGradKernel(
const Context& dev_ctx,
const DenseTensor& x,
const paddle::optional<DenseTensor>& out_size,
const paddle::optional<std::vector<const DenseTensor*>>& size_tensor,
const paddle::optional<DenseTensor>& scale_tensor,
const DenseTensor& out_grad,
const std::string& data_layout,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* x_grad);
} // namespace phi } // namespace phi
...@@ -61,7 +61,10 @@ DECLARE_SPARSE_UNARY_GRAD_KERNEL(Square) ...@@ -61,7 +61,10 @@ DECLARE_SPARSE_UNARY_GRAD_KERNEL(Square)
DECLARE_SPARSE_UNARY_GRAD_KERNEL(Sqrt) DECLARE_SPARSE_UNARY_GRAD_KERNEL(Sqrt)
DECLARE_SPARSE_UNARY_GRAD_KERNEL(Log1p) DECLARE_SPARSE_UNARY_GRAD_KERNEL(Log1p)
DECLARE_SPARSE_UNARY_GRAD_KERNEL(Abs) DECLARE_SPARSE_UNARY_GRAD_KERNEL(Abs)
DECLARE_SPARSE_UNARY_GRAD_KERNEL(Expm1)
DECLARE_SPARSE_UNARY_GRAD_KERNEL_WITH_ONE_ATTR(Pow, factor) DECLARE_SPARSE_UNARY_GRAD_KERNEL_WITH_ONE_ATTR(Pow, factor)
DECLARE_SPARSE_UNARY_GRAD_KERNEL_WITH_ONE_ATTR(LeakyRelu, alpha)
DECLARE_SPARSE_UNARY_GRAD_KERNEL_WITH_ONE_ATTR(Relu6, threshold)
template <typename T, typename Context> template <typename T, typename Context>
void CastCooGradKernel(const Context& dev_ctx, void CastCooGradKernel(const Context& dev_ctx,
......
...@@ -57,7 +57,20 @@ DECLARE_SPARSE_UNARY_KERNEL(Square) ...@@ -57,7 +57,20 @@ DECLARE_SPARSE_UNARY_KERNEL(Square)
DECLARE_SPARSE_UNARY_KERNEL(Sqrt) DECLARE_SPARSE_UNARY_KERNEL(Sqrt)
DECLARE_SPARSE_UNARY_KERNEL(Log1p) DECLARE_SPARSE_UNARY_KERNEL(Log1p)
DECLARE_SPARSE_UNARY_KERNEL(Abs) DECLARE_SPARSE_UNARY_KERNEL(Abs)
DECLARE_SPARSE_UNARY_KERNEL(Expm1)
DECLARE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(Pow, factor) DECLARE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(Pow, factor)
DECLARE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(Relu6Raw, threshold)
DECLARE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(LeakyRelu, alpha)
template <typename T, typename Context>
void Relu6CooKernel(const Context& dev_ctx,
const SparseCooTensor& x,
SparseCooTensor* out);
template <typename T, typename Context>
void Relu6CsrKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
SparseCsrTensor* out);
template <typename T, typename Context> template <typename T, typename Context>
void ScaleCooKernel(const Context& dev_ctx, void ScaleCooKernel(const Context& dev_ctx,
......
...@@ -16,7 +16,7 @@ limitations under the License. */ ...@@ -16,7 +16,7 @@ limitations under the License. */
namespace phi { namespace phi {
template <typename T, typename Context> template <typename T, typename Context>
void SpectrumNormGradKernel(const Context& dev_ctx, void SpectralNormGradKernel(const Context& dev_ctx,
const DenseTensor& weight, const DenseTensor& weight,
const DenseTensor& u, const DenseTensor& u,
const DenseTensor& v, const DenseTensor& v,
......
...@@ -16,7 +16,7 @@ limitations under the License. */ ...@@ -16,7 +16,7 @@ limitations under the License. */
namespace phi { namespace phi {
template <typename T, typename Context> template <typename T, typename Context>
void SpectrumNormKernel(const Context& dev_ctx, void SpectralNormKernel(const Context& dev_ctx,
const DenseTensor& weight, const DenseTensor& weight,
const DenseTensor& u, const DenseTensor& u,
const DenseTensor& v, const DenseTensor& v,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册