未验证 提交 c9699556 编写于 作者: H huangxu96 提交者: GitHub

Optimize where_op and abs_grad_op by the elementwise interface (#39609)

* Optimize the where_op by the elementwise_op funtion

* Modified where_op & abs_grad_op by elementwise interface
上级 867224b2
......@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/operators/where_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
......@@ -20,6 +21,15 @@ namespace platform = paddle::platform;
namespace paddle {
namespace operators {
template <typename T>
struct CondFunctor {
HOSTDEVICE inline CondFunctor() {}
HOSTDEVICE inline T operator()(const bool cond, const T x, const T y) const {
return cond ? x : y;
}
};
template <typename T>
__global__ void WhereCUDAKernel(const int N, const bool* cond, const T* x,
const T* y, T* out) {
......@@ -63,10 +73,11 @@ class WhereKernel<platform::CUDADeviceContext, T>
auto stream = context.cuda_device_context().stream();
auto& dev_ctx =
context.template device_context<platform::CUDADeviceContext>();
auto config = GetGpuLaunchConfig1D(dev_ctx, numel);
WhereCUDAKernel<
T><<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
numel, cond_data, x_data, y_data, out_data);
auto functor = CondFunctor<T>();
std::vector<const framework::Tensor*> ins = {condition, X, Y};
std::vector<framework::Tensor*> outs = {out};
paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(dev_ctx, ins,
&outs, functor);
}
};
......
......@@ -154,6 +154,53 @@ struct AbsFunctor<T, NoComplex<T, Real<T>>> {
int64_t numel_;
};
template <typename T>
struct AbsGradCUDAFunctor {
HOSTDEVICE inline AbsGradCUDAFunctor() {}
HOSTDEVICE inline T operator()(const T x, const T dout) const {
T output;
if (x == T(0)) {
output = T(0);
} else {
output = T(dout) * (x / T(std::abs(x)));
}
return output;
}
};
template <>
struct AbsGradCUDAFunctor<phi::dtype::complex<float>> {
HOSTDEVICE inline AbsGradCUDAFunctor() {}
HOSTDEVICE inline phi::dtype::complex<float> operator()(
const phi::dtype::complex<float> x, const float dout) const {
phi::dtype::complex<float> output;
if (x == phi::dtype::complex<float>(0)) {
output = phi::dtype::complex<float>(0);
} else {
output = phi::dtype::complex<float>(dout) *
(x / phi::dtype::complex<float>(abs(x)));
}
return output;
}
};
template <>
struct AbsGradCUDAFunctor<phi::dtype::complex<double>> {
HOSTDEVICE inline AbsGradCUDAFunctor() {}
HOSTDEVICE inline phi::dtype::complex<double> operator()(
const phi::dtype::complex<double> x, const double dout) const {
phi::dtype::complex<double> output;
if (x == phi::dtype::complex<double>(0)) {
output = phi::dtype::complex<double>(0);
} else {
output = phi::dtype::complex<double>(dout) *
(x / phi::dtype::complex<double>(abs(x)));
}
return output;
}
};
template <typename T>
struct AbsGradFunctor {
AbsGradFunctor(const Real<T>* dout, const T* x, T* output, int64_t numel)
......
......@@ -17,9 +17,30 @@
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/abs_grad_kernel.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
namespace phi {
#if defined(__NVCC__)
template <typename T>
void AbsGradKernelImpl(const GPUContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& dout,
DenseTensor* dx) {
std::vector<const DenseTensor*> ins = {&x, &dout};
std::vector<DenseTensor*> outs = {dx};
dev_ctx.Alloc<T>(dx);
phi::funcs::AbsGradCUDAFunctor<T> abs_grad_cuda_functor;
phi::funcs::ElementwiseKernel<T>(dev_ctx, ins, &outs, abs_grad_cuda_functor);
}
template <typename T, typename Context>
void AbsGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& dout,
DenseTensor* dx) {
AbsGradKernelImpl<T>(dev_ctx, x, dout, dx);
}
#else
template <typename T, typename Context>
void AbsGradKernel(const Context& ctx,
const DenseTensor& x,
......@@ -37,6 +58,7 @@ void AbsGradKernel(const Context& ctx,
for_range(functor);
}
#endif
template <typename T, typename Context>
void AbsDoubleGradKernel(const Context& ctx,
const DenseTensor& x,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册