Unverified commit 92cae577, authored by YuanRisheng, committed via GitHub

Supplement header files' code (#50826)

Parent commit: bfa217e4
......@@ -39,12 +39,12 @@ DECALRE_COMPARE_KERNEL(Equal)
DECALRE_COMPARE_KERNEL(NotEqual)
#undef DECALRE_COMPARE_KERNEL
#define DECALRE_COMPARE_ALL_KERNEL(compare_all_kernel) \
template <typename T, typename Context> \
void compare_all_kernel(const Context& ctx, \
const DenseTensor& x, \
const DenseTensor& y, \
DenseTensor* out);
// Declares a whole-tensor comparison kernel: `compare_all` is token-pasted
// with the `Kernel` suffix, so DECALRE_COMPARE_ALL_KERNEL(EqualAll) declares
// EqualAllKernel, which compares tensors `x` and `y` and writes the result
// to `out`.
// NOTE: "DECALRE" (sic) deliberately keeps the existing misspelling used by
// the sibling DECALRE_COMPARE_KERNEL macro above, for consistency.
#define DECALRE_COMPARE_ALL_KERNEL(compare_all)   \
  template <typename T, typename Context>         \
  void compare_all##Kernel(const Context& ctx,    \
                           const DenseTensor& x,  \
                           const DenseTensor& y,  \
                           DenseTensor* out);
DECALRE_COMPARE_ALL_KERNEL(EqualAll)
// Fix: undef the macro defined in this section. The previous code undef'd
// DECALRE_COMPARE_KERNEL (already #undef'd earlier in this header), which
// let DECALRE_COMPARE_ALL_KERNEL leak out to every includer.
#undef DECALRE_COMPARE_ALL_KERNEL
......
......@@ -19,7 +19,7 @@
namespace phi {
template <typename T, typename Context>
void EigGardKernel(const Context& dev_ctx,
void EigGradKernel(const Context& dev_ctx,
const DenseTensor& out_w,
const DenseTensor& out_v,
const DenseTensor& dout_w,
......
......@@ -19,7 +19,7 @@
namespace phi {
template <typename T, typename Context>
void EighGardKernel(const Context& dev_ctx,
void EighGradKernel(const Context& dev_ctx,
const DenseTensor& out_w,
const DenseTensor& out_v,
const DenseTensor& dout_w,
......
......@@ -19,7 +19,7 @@
namespace phi {
template <typename T, typename Context>
void GatherGradNdKernel(const Context &ctx,
void GatherNdGradKernel(const Context &ctx,
const DenseTensor &x,
const DenseTensor &index,
const DenseTensor &out_grad,
......
......@@ -36,4 +36,76 @@ void BilinearInterpGradKernel(
int align_mode,
DenseTensor* x_grad);
// Declaration of the gradient kernel for linear interpolation: given the
// forward input `x` and the upstream gradient `out_grad`, fills `x_grad`
// with the gradient w.r.t. `x`.
// NOTE(review): parameter semantics inferred from names/signature only —
// `out_size`, `size_tensor` and `scale_tensor` appear to be optional
// runtime overrides of the static `out_d`/`out_h`/`out_w` and `scale`
// attributes; confirm against the kernel implementation.
template <typename T, typename Context>
void LinearInterpGradKernel(
    const Context& dev_ctx,
    const DenseTensor& x,
    const paddle::optional<DenseTensor>& out_size,
    const paddle::optional<std::vector<const DenseTensor*>>& size_tensor,
    const paddle::optional<DenseTensor>& scale_tensor,
    const DenseTensor& out_grad,
    const std::string& data_layout,
    int out_d,
    int out_h,
    int out_w,
    const std::vector<float>& scale,
    const std::string& interp_method,
    bool align_corners,
    int align_mode,
    DenseTensor* x_grad);
// Declaration of the gradient kernel for trilinear interpolation: computes
// the gradient of the forward input `x` from `out_grad` and writes it to
// `x_grad`. Signature mirrors the other *InterpGradKernel declarations in
// this header.
// NOTE(review): `out_size`/`size_tensor`/`scale_tensor` look like optional
// runtime overrides of `out_d`/`out_h`/`out_w` and `scale` — verify against
// the implementation.
template <typename T, typename Context>
void TrilinearInterpGradKernel(
    const Context& dev_ctx,
    const DenseTensor& x,
    const paddle::optional<DenseTensor>& out_size,
    const paddle::optional<std::vector<const DenseTensor*>>& size_tensor,
    const paddle::optional<DenseTensor>& scale_tensor,
    const DenseTensor& out_grad,
    const std::string& data_layout,
    int out_d,
    int out_h,
    int out_w,
    const std::vector<float>& scale,
    const std::string& interp_method,
    bool align_corners,
    int align_mode,
    DenseTensor* x_grad);
// Declaration of the gradient kernel for nearest-neighbor interpolation:
// computes the gradient of the forward input `x` from `out_grad` and writes
// it to `x_grad`. Signature mirrors the other *InterpGradKernel
// declarations in this header.
// NOTE(review): `align_mode` is accepted here for signature uniformity; its
// effect (if any) on nearest-neighbor mode is not visible from this header —
// confirm against the implementation.
template <typename T, typename Context>
void NearestInterpGradKernel(
    const Context& dev_ctx,
    const DenseTensor& x,
    const paddle::optional<DenseTensor>& out_size,
    const paddle::optional<std::vector<const DenseTensor*>>& size_tensor,
    const paddle::optional<DenseTensor>& scale_tensor,
    const DenseTensor& out_grad,
    const std::string& data_layout,
    int out_d,
    int out_h,
    int out_w,
    const std::vector<float>& scale,
    const std::string& interp_method,
    bool align_corners,
    int align_mode,
    DenseTensor* x_grad);
// Declaration of the gradient kernel for bicubic interpolation: computes
// the gradient of the forward input `x` from `out_grad` and writes it to
// `x_grad`. Signature mirrors the other *InterpGradKernel declarations in
// this header.
// NOTE(review): parameter meanings inferred from names — `out_d`/`out_h`/
// `out_w` are presumably the target spatial sizes, optionally overridden at
// runtime by `out_size`/`size_tensor`/`scale_tensor`; confirm in the
// implementation.
template <typename T, typename Context>
void BicubicInterpGradKernel(
    const Context& dev_ctx,
    const DenseTensor& x,
    const paddle::optional<DenseTensor>& out_size,
    const paddle::optional<std::vector<const DenseTensor*>>& size_tensor,
    const paddle::optional<DenseTensor>& scale_tensor,
    const DenseTensor& out_grad,
    const std::string& data_layout,
    int out_d,
    int out_h,
    int out_w,
    const std::vector<float>& scale,
    const std::string& interp_method,
    bool align_corners,
    int align_mode,
    DenseTensor* x_grad);
} // namespace phi
......@@ -61,7 +61,10 @@ DECLARE_SPARSE_UNARY_GRAD_KERNEL(Square)
DECLARE_SPARSE_UNARY_GRAD_KERNEL(Sqrt)
DECLARE_SPARSE_UNARY_GRAD_KERNEL(Log1p)
DECLARE_SPARSE_UNARY_GRAD_KERNEL(Abs)
DECLARE_SPARSE_UNARY_GRAD_KERNEL(Expm1)
DECLARE_SPARSE_UNARY_GRAD_KERNEL_WITH_ONE_ATTR(Pow, factor)
DECLARE_SPARSE_UNARY_GRAD_KERNEL_WITH_ONE_ATTR(LeakyRelu, alpha)
DECLARE_SPARSE_UNARY_GRAD_KERNEL_WITH_ONE_ATTR(Relu6, threshold)
template <typename T, typename Context>
void CastCooGradKernel(const Context& dev_ctx,
......
......@@ -57,7 +57,20 @@ DECLARE_SPARSE_UNARY_KERNEL(Square)
DECLARE_SPARSE_UNARY_KERNEL(Sqrt)
DECLARE_SPARSE_UNARY_KERNEL(Log1p)
DECLARE_SPARSE_UNARY_KERNEL(Abs)
DECLARE_SPARSE_UNARY_KERNEL(Expm1)
DECLARE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(Pow, factor)
DECLARE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(Relu6Raw, threshold)
DECLARE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(LeakyRelu, alpha)
// ReLU6 activation declarations for sparse tensors, one per sparse layout
// (COO and CSR); each reads `x` and writes the activated result to `out`.
// NOTE(review): unlike Relu6Raw (declared earlier in this header with a
// `threshold` attribute), these take no threshold parameter — presumably
// the cap is fixed at 6.0; confirm against the implementation.
template <typename T, typename Context>
void Relu6CooKernel(const Context& dev_ctx,
                    const SparseCooTensor& x,
                    SparseCooTensor* out);
// CSR-layout counterpart of Relu6CooKernel.
template <typename T, typename Context>
void Relu6CsrKernel(const Context& dev_ctx,
                    const SparseCsrTensor& x,
                    SparseCsrTensor* out);
template <typename T, typename Context>
void ScaleCooKernel(const Context& dev_ctx,
......
......@@ -16,7 +16,7 @@ limitations under the License. */
namespace phi {
template <typename T, typename Context>
void SpectrumNormGradKernel(const Context& dev_ctx,
void SpectralNormGradKernel(const Context& dev_ctx,
const DenseTensor& weight,
const DenseTensor& u,
const DenseTensor& v,
......
......@@ -16,7 +16,7 @@ limitations under the License. */
namespace phi {
template <typename T, typename Context>
void SpectrumNormKernel(const Context& dev_ctx,
void SpectralNormKernel(const Context& dev_ctx,
const DenseTensor& weight,
const DenseTensor& u,
const DenseTensor& v,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.