Unverified Commit c2f825d7 authored by Lijunhui, committed by GitHub

optimize elementwise_min_grad using new reduce interface (#38236)

* init commit

* multi-outputs init commit

* optimize code

* remove inplace
Parent 12c5b1fe
@@ -233,6 +233,32 @@ struct FMinFunctor<int64_t> {
  }
};

template <typename T>
struct MinGradXFunctor {
  // dx = dout * (x < y)
  inline HOSTDEVICE T operator()(const T& x, const T& y, const T& dout) const {
    return dout * static_cast<T>(x < y);
  }
};

template <typename T>
struct MinGradYFunctor {
  // dy = dout * (x >= y)
  inline HOSTDEVICE T operator()(const T& x, const T& y, const T& dout) const {
    return dout * static_cast<T>(x >= y);
  }
};

template <typename InT, typename OutT>
struct MinGradXYFunctor {
  // Fused functor: returns dx and dy together so a single kernel launch can
  // produce both gradients.
  inline HOSTDEVICE paddle::framework::Array<OutT, 2> operator()(
      const InT& x, const InT& y, const InT& dout) {
    paddle::framework::Array<OutT, 2> outs;
    // dx = dout * (x < y)
    outs[0] = static_cast<OutT>(dout * static_cast<InT>(x < y));
    // dy = dout * (x >= y)
    outs[1] = static_cast<OutT>(dout * static_cast<InT>(x >= y));
    return outs;
  }
};

template <typename T>
struct MulGradFunctor {
  inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a * b; }
......
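The three functors implement the subgradient of z = min(x, y): upstream dout flows to x where x < y and to y where x >= y, so ties route the whole gradient to y and dx + dy always equals dout. A minimal host-side sketch (not part of the diff; std::array stands in for paddle::framework::Array and the HOSTDEVICE qualifier is dropped):

#include <array>
#include <cassert>

// Scalar re-implementation of MinGradXYFunctor, for illustration only.
template <typename T>
std::array<T, 2> min_grad_xy(T x, T y, T dout) {
  // dx = dout * (x < y); dy = dout * (x >= y). Ties send the gradient to y.
  return {dout * static_cast<T>(x < y), dout * static_cast<T>(x >= y)};
}

int main() {
  auto g1 = min_grad_xy(1.0, 2.0, 0.5);  // x < y: dx receives the gradient
  assert(g1[0] == 0.5 && g1[1] == 0.0);
  auto g2 = min_grad_xy(2.0, 2.0, 0.5);  // tie: dy receives the gradient
  assert(g2[0] == 0.0 && g2[1] == 0.5);
  return 0;
}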
@@ -24,15 +24,41 @@ class ElementwiseMinKernel<platform::CUDADeviceContext, T>
  void Compute(const framework::ExecutionContext& ctx) const override {
    std::vector<const framework::Tensor*> ins;
    std::vector<framework::Tensor*> outs;
-   const auto& cuda_ctx =
+   const auto& dev_ctx =
        ctx.template device_context<platform::CUDADeviceContext>();
    int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
    LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
-       cuda_ctx, ins, &outs, axis, MinFunctor<T>());
+       dev_ctx, ins, &outs, axis, MinFunctor<T>());
  }
};

template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
ElementwiseMinGrad(const framework::ExecutionContext& ctx,
                   const framework::Tensor* x, const framework::Tensor* y,
                   const framework::Tensor* out, const framework::Tensor* dout,
                   framework::Tensor* dx, framework::Tensor* dy) {
  int axis = ctx.Attr<int>("axis");
  const auto& dev_ctx =
      ctx.template device_context<platform::CUDADeviceContext>();
  const auto place = ctx.GetPlace();

  if (dx != nullptr && dy != nullptr) {
    // Both gradients requested: one fused kernel computes dx and dy together.
    std::vector<const framework::Tensor*> ins = {x, y, dout};
    GetGradXAndYOut<ElementwiseType::kTernary, T>(
        dev_ctx, place, axis, ins, dout, dx, dy, MinGradXYFunctor<T, T>());
  } else if (dx != nullptr && dy == nullptr) {
    // Only dx requested.
    std::vector<const framework::Tensor*> ins = {x, y, dout};
    GetGradXOrYOut<ElementwiseType::kTernary, T>(
        dev_ctx, place, axis, ins, dout, dx, MinGradXFunctor<T>());
  } else if (dx == nullptr && dy != nullptr) {
    // Only dy requested.
    std::vector<const framework::Tensor*> ins = {x, y, dout};
    GetGradXOrYOut<ElementwiseType::kTernary, T>(
        dev_ctx, place, axis, ins, dout, dy, MinGradYFunctor<T>());
  }
}
} // namespace operators
} // namespace paddle
......
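ElementwiseMinGrad branches on which outputs the framework actually requested: both gradients, only dx, or only dy. The fused MinGradXYFunctor path is the point of the optimization: it reads x, y, and dout once and writes both outputs from a single kernel launch, instead of two launches that each re-read the inputs. A hedged host-side sketch of the same branching (the null-pointer convention is the diff's; the plain loop is a stand-in for the GetGradXAndYOut / GetGradXOrYOut launches):

#include <cstddef>
#include <vector>

// Illustrative only: one pass over the inputs, writing whichever gradient
// buffers were requested, mirroring the three functors above.
template <typename T>
void min_grad(const std::vector<T>& x, const std::vector<T>& y,
              const std::vector<T>& dout,
              std::vector<T>* dx, std::vector<T>* dy) {
  for (std::size_t i = 0; i < x.size(); ++i) {
    if (dx != nullptr) (*dx)[i] = dout[i] * static_cast<T>(x[i] < y[i]);
    if (dy != nullptr) (*dy)[i] = dout[i] * static_cast<T>(x[i] >= y[i]);
  }
}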
@@ -86,6 +86,28 @@ struct MinGradDy<platform::float16> {
};
#endif

template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
ElementwiseMinGrad(const framework::ExecutionContext& ctx,
                   const framework::Tensor* x, const framework::Tensor* y,
                   const framework::Tensor* out, const framework::Tensor* dout,
                   framework::Tensor* dx, framework::Tensor* dy) {
  int axis = ctx.Attr<int>("axis");
  ElemwiseGradCompute<DeviceContext, T, MinGradDx<T>, MinGradDy<T>>(
      ctx, *x, *y, *out, *dout, axis, dx, dy, MinGradDx<T>(), MinGradDy<T>());
}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
ElementwiseMinGrad(const framework::ExecutionContext& ctx,
                   const framework::Tensor* x, const framework::Tensor* y,
                   const framework::Tensor* out, const framework::Tensor* dout,
                   framework::Tensor* dx, framework::Tensor* dy);
#endif

template <typename DeviceContext, typename T>
class ElementwiseMinGradKernel : public ElemwiseGradKernel<T> {
public:
@@ -99,9 +121,7 @@ class ElementwiseMinGradKernel : public ElemwiseGradKernel<T> {
    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
    auto* out = dout;  // Fake out, not used
-   int axis = ctx.Attr<int>("axis");
-   ElemwiseGradCompute<DeviceContext, T, MinGradDx<T>, MinGradDy<T>>(
-       ctx, *x, *y, *out, *dout, axis, dx, dy, MinGradDx<T>(), MinGradDy<T>());
+   ElementwiseMinGrad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
  }
};
......
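The header now dispatches by device: the two ElementwiseMinGrad overloads are constrained with std::enable_if on DeviceContext, so the CPU definition above is selected for CPUDeviceContext, while only a declaration of the CUDA overload is visible here (its definition lives in the .cu file), and ElementwiseMinGradKernel::Compute simply forwards to whichever overload matches. A self-contained sketch of this SFINAE dispatch pattern, with hypothetical tag types in place of Paddle's device contexts:

#include <iostream>
#include <type_traits>

struct CPUDeviceContext {};
struct CUDADeviceContext {};

// Selected when DeviceContext is the CPU tag; enable_if yields void.
template <typename DeviceContext>
typename std::enable_if<
    std::is_same<DeviceContext, CPUDeviceContext>::value>::type
Run() { std::cout << "CPU path\n"; }

// Selected when DeviceContext is the CUDA tag.
template <typename DeviceContext>
typename std::enable_if<
    std::is_same<DeviceContext, CUDADeviceContext>::value>::type
Run() { std::cout << "CUDA path\n"; }

int main() {
  Run<CPUDeviceContext>();   // prints "CPU path"
  Run<CUDADeviceContext>();  // prints "CUDA path"
  return 0;
}

Because only the CUDA declaration appears in the header, a CPU-only build never compiles the CUDA body, and the #if guard hides even the declaration when neither PADDLE_WITH_CUDA nor PADDLE_WITH_HIP is defined.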