Unverified    Commit 4a64ca1e    authored by: Lijunhui    committed by: GitHub

optimize elementwise_max_grad using new interfaces (#37906)

* init elem_max_grad op

* optimize code and address review comments

* ternary functors

* apply new reduce func

* move functor to .h

* multi-outputs init

* rearrange code

* modified functors

* optimize code

* pass nullptr

* revert the last change as a seg fault occurred

* optimize code

* remove inplace

* remove comments
Parent: 5f5f626b
@@ -301,5 +301,32 @@ struct MulGradXYFunctor<Complex<InT>, Complex<OutT>> {
   }
 };

+// Ternary compare
+template <typename T>
+struct MaxGradXFunctor {
+  inline HOSTDEVICE T operator()(const T& x, const T& y, const T& dout) const {
+    return dout * static_cast<T>(x > y);
+  }
+};
+
+template <typename T>
+struct MaxGradYFunctor {
+  inline HOSTDEVICE T operator()(const T& x, const T& y, const T& dout) const {
+    return dout * static_cast<T>(x <= y);
+  }
+};
+
+template <typename InT, typename OutT>
+struct MaxGradXYFunctor {
+  inline HOSTDEVICE paddle::framework::Array<OutT, 2> operator()(
+      const InT& x, const InT& y, const InT& dout) {
+    paddle::framework::Array<OutT, 2> outs;
+    // dx = dout * (x > y)
+    outs[0] = static_cast<OutT>(dout * static_cast<InT>(x > y));
+    // dy = dout * (x <= y)
+    outs[1] = static_cast<OutT>(dout * static_cast<InT>(x <= y));
+    return outs;
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
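Note: the three functors above encode the subgradient of z = max(x, y): dx = dout * (x > y) and dy = dout * (x <= y), so on ties the whole gradient is routed to y, and dx + dy == dout holds elementwise. A minimal standalone sketch of that rule (plain C++; helper names here are hypothetical, not Paddle API):

```cpp
#include <cassert>

// Plain restatement of MaxGradXFunctor / MaxGradYFunctor (illustrative only).
template <typename T>
T max_grad_x(T x, T y, T dout) { return dout * static_cast<T>(x > y); }

template <typename T>
T max_grad_y(T x, T y, T dout) { return dout * static_cast<T>(x <= y); }

int main() {
  // x wins: the whole upstream gradient flows to x.
  assert(max_grad_x(3.0, 1.0, 0.5) == 0.5);
  assert(max_grad_y(3.0, 1.0, 0.5) == 0.0);
  // Tie: x <= y is true, so the gradient is credited to y, not x.
  assert(max_grad_x(2.0, 2.0, 0.5) == 0.0);
  assert(max_grad_y(2.0, 2.0, 0.5) == 0.5);
  // In every case the two results sum to dout: the gradient is
  // partitioned between x and y, never duplicated.
  return 0;
}
```

MaxGradXYFunctor packs both results into a paddle::framework::Array<OutT, 2> so that a single fused kernel can emit dx and dy in one pass.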
@@ -24,15 +24,41 @@ class ElementwiseMaxKernel<platform::CUDADeviceContext, T>
   void Compute(const framework::ExecutionContext& ctx) const override {
     std::vector<const framework::Tensor*> ins;
     std::vector<framework::Tensor*> outs;
-    const auto& cuda_ctx =
+    const auto& dev_ctx =
         ctx.template device_context<platform::CUDADeviceContext>();
     int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
     LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
-        cuda_ctx, ins, &outs, axis, MaxFunctor<T>());
+        dev_ctx, ins, &outs, axis, MaxFunctor<T>());
   }
 };

+template <typename DeviceContext, typename T>
+typename std::enable_if<
+    std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
+ElementwiseMaxGrad(const framework::ExecutionContext& ctx,
+                   const framework::Tensor* x, const framework::Tensor* y,
+                   const framework::Tensor* out, const framework::Tensor* dout,
+                   framework::Tensor* dx, framework::Tensor* dy) {
+  int axis = ctx.Attr<int>("axis");
+  const auto& dev_ctx =
+      ctx.template device_context<platform::CUDADeviceContext>();
+  const auto place = ctx.GetPlace();
+
+  if (dx != nullptr && dy != nullptr) {
+    std::vector<const framework::Tensor*> ins = {x, y, dout};
+    GetGradXAndYOut<ElementwiseType::kTernary, T>(
+        dev_ctx, place, axis, ins, dout, dx, dy, MaxGradXYFunctor<T, T>());
+  } else if (dx != nullptr && dy == nullptr) {
+    std::vector<const framework::Tensor*> ins = {x, y, dout};
+    GetGradXOrYOut<ElementwiseType::kTernary, T>(
+        dev_ctx, place, axis, ins, dout, dx, MaxGradXFunctor<T>());
+  } else if (dx == nullptr && dy != nullptr) {
+    std::vector<const framework::Tensor*> ins = {x, y, dout};
+    GetGradXOrYOut<ElementwiseType::kTernary, T>(
+        dev_ctx, place, axis, ins, dout, dy, MaxGradYFunctor<T>());
+  }
+}
+
 }  // namespace operators
 }  // namespace paddle
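Note: the branch structure mirrors which gradients the op actually has to produce. When both dx and dy are live, GetGradXAndYOut launches one fused ternary kernel with the multi-output MaxGradXYFunctor, reading x, y, and dout once while writing both outputs; when only one side is needed, GetGradXOrYOut runs the cheaper single-output functor. A simplified sketch of that nullptr-driven dispatch (mock types and helpers, not the Paddle API):

```cpp
#include <cstdio>

// Mock stand-ins for Paddle's tensors and launch helpers; all names
// below are hypothetical, for illustration only.
struct Tensor { const char* name; };

// Stand-in for GetGradXAndYOut: one fused launch fills both outputs.
void ComputeBoth(Tensor* dx, Tensor* dy) {
  std::printf("fused ternary kernel -> %s and %s\n", dx->name, dy->name);
}
// Stand-in for GetGradXOrYOut: single-output launch.
void ComputeOne(Tensor* d) {
  std::printf("single-output kernel -> %s\n", d->name);
}

void ElementwiseMaxGradMock(Tensor* dx, Tensor* dy) {
  if (dx != nullptr && dy != nullptr) {
    ComputeBoth(dx, dy);  // both grads requested: fused path
  } else if (dx != nullptr) {
    ComputeOne(dx);       // only dX is live
  } else if (dy != nullptr) {
    ComputeOne(dy);       // only dY is live
  }                       // neither requested: nothing to launch
}

int main() {
  Tensor dx{"dX"}, dy{"dY"};
  ElementwiseMaxGradMock(&dx, &dy);      // fused path
  ElementwiseMaxGradMock(&dx, nullptr);  // dX-only path
  ElementwiseMaxGradMock(nullptr, &dy);  // dY-only path
  return 0;
}
```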
@@ -64,6 +64,28 @@ struct MaxGradDy {
   }
 };

+template <typename DeviceContext, typename T>
+typename std::enable_if<
+    std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
+ElementwiseMaxGrad(const framework::ExecutionContext& ctx,
+                   const framework::Tensor* x, const framework::Tensor* y,
+                   const framework::Tensor* out, const framework::Tensor* dout,
+                   framework::Tensor* dx, framework::Tensor* dy) {
+  int axis = ctx.Attr<int>("axis");
+  ElemwiseGradCompute<DeviceContext, T, MaxGradDx<T>, MaxGradDy<T>>(
+      ctx, *x, *y, *out, *dout, axis, dx, dy, MaxGradDx<T>(), MaxGradDy<T>());
+}
+
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+template <typename DeviceContext, typename T>
+typename std::enable_if<
+    std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
+ElementwiseMaxGrad(const framework::ExecutionContext& ctx,
+                   const framework::Tensor* x, const framework::Tensor* y,
+                   const framework::Tensor* out, const framework::Tensor* dout,
+                   framework::Tensor* dx, framework::Tensor* dy);
+#endif
+
 template <typename DeviceContext, typename T>
 class ElementwiseMaxGradKernel : public ElemwiseGradKernel<T> {
  public:
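Note: the CPU definition and the CUDA declaration in the hunk above share one name; std::enable_if on the return type removes whichever overload's DeviceContext test fails, so a call to ElementwiseMaxGrad<DeviceContext, T>(...) resolves at compile time with no runtime branch (the CUDA body lives in the .cu file). A self-contained sketch of that SFINAE selection, using dummy context types for illustration:

```cpp
#include <iostream>
#include <type_traits>

// Dummy stand-ins for platform::CPUDeviceContext / CUDADeviceContext.
struct CPUDeviceContext {};
struct CUDADeviceContext {};

// Enabled only when DeviceContext is CPUDeviceContext; otherwise this
// overload drops out of the candidate set (SFINAE), not a compile error.
template <typename DeviceContext>
typename std::enable_if<
    std::is_same<DeviceContext, CPUDeviceContext>::value>::type
Grad() {
  std::cout << "CPU path\n";
}

// Enabled only when DeviceContext is CUDADeviceContext.
template <typename DeviceContext>
typename std::enable_if<
    std::is_same<DeviceContext, CUDADeviceContext>::value>::type
Grad() {
  std::cout << "CUDA path\n";
}

int main() {
  Grad<CPUDeviceContext>();   // picks the first overload
  Grad<CUDADeviceContext>();  // picks the second overload
  return 0;
}
```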
@@ -74,12 +96,11 @@ class ElementwiseMaxGradKernel : public ElemwiseGradKernel<T> {
     auto* x = ctx.Input<Tensor>("X");
     auto* y = ctx.Input<Tensor>("Y");
     auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
-    auto* out = dout;  // out is not necessary
     auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
+    auto* out = dout;  // Fake out, not used
-    int axis = ctx.Attr<int>("axis");
-    ElemwiseGradCompute<DeviceContext, T, MaxGradDx<T>, MaxGradDy<T>>(
-        ctx, *x, *y, *out, *dout, axis, dx, dy, MaxGradDx<T>(), MaxGradDy<T>());
+    ElementwiseMaxGrad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
   }
 };
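Note: after this change the kernel's Compute is a thin wrapper: it gathers X, Y, and dOut, aliases out to dout (the max gradient only reads x, y, and dout, never the forward output), and forwards to the device-specific ElementwiseMaxGrad. For reference, the semantics both backends implement, written as a plain loop (illustrative only; MaxGradRef is a hypothetical name, and the real kernels additionally handle broadcasting via axis):

```cpp
#include <cstdio>

// Reference semantics of elementwise_max_grad, same-shape case only.
void MaxGradRef(const float* x, const float* y, const float* dout,
                float* dx, float* dy, int n) {
  for (int i = 0; i < n; ++i) {
    dx[i] = dout[i] * (x[i] > y[i] ? 1.f : 0.f);   // MaxGradXFunctor
    dy[i] = dout[i] * (x[i] <= y[i] ? 1.f : 0.f);  // MaxGradYFunctor
  }
}

int main() {
  const float x[] = {1.f, 4.f, 2.f};
  const float y[] = {3.f, 4.f, 0.f};
  const float dout[] = {1.f, 1.f, 1.f};
  float dx[3], dy[3];
  MaxGradRef(x, y, dout, dx, dy, 3);
  // Expected: dx = {0, 0, 1}, dy = {1, 1, 0}; the tie at index 1 credits y.
  for (int i = 0; i < 3; ++i) std::printf("dx=%g dy=%g\n", dx[i], dy[i]);
  return 0;
}
```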