未验证 提交 86685190 编写于 作者: L limingshu 提交者: GitHub

Optimization of pool2d grad (#35389)

* Optimization of pool2d grad, first commit.

* remove useless print codes

* refine codes

* refine codes

* seal more operation into template specialization

* fix template struct error in MaxPool2dGrad.

* Fix header including error

* refine code with comment

* Seal the param-preparation codes into function for common use.

* Seal the param-preparation codes into function for common use.

* Seal the param-preparation into funciton and make it common for other kernels

* polish code and erase useless template speicalization

* Rerun triger

* rerun trigger
上级 9f88d327
......@@ -68,8 +68,9 @@ class AvgPool {
template <class T>
class MaxPoolGrad {
public:
DEVICE inline void compute(const T& x, const T& y, const T& dy, T scale,
T* dx) {
static constexpr bool use_x = true;
HOSTDEVICE inline void compute(const T& x, const T& y, const T& dy, T scale,
T* dx) {
*dx += dy * static_cast<T>(x == y);
}
};
......@@ -77,8 +78,9 @@ class MaxPoolGrad {
template <class T>
class AvgPoolGrad {
public:
DEVICE inline void compute(const T& x, const T& y, const T& dy, T scale,
T* dx) {
static constexpr bool use_x = false;
HOSTDEVICE inline void compute(const T& x, const T& y, const T& dy, T scale,
T* dx) {
*dx += (scale * dy);
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册