Unverified commit b432d024 · Author: limingshu · Committer: GitHub

Support Add Sub Mul Max Min Pow binary functors in elementwise system (#33050)

Parent 9c52adef
......@@ -21,21 +21,21 @@ namespace plat = paddle::platform;
namespace paddle {
namespace operators {
#define DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(Func, op) \
#define DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(func, op) \
template <typename T, typename Enable = void> \
struct Func##Functor { \
struct func { \
using ELEMENT_TYPE = T; \
inline HOSTDEVICE bool operator()(const T* args) const { \
return args[0] op args[1]; \
} \
};
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaLessThan, <)
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaLessEqual, <=)
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaGreaterThan, >)
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaGreaterEqual, >=)
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaEqual, ==)
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaNotEqual, !=)
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaLessThanFunctor, <)
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaLessEqualFunctor, <=)
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaGreaterThanFunctor, >)
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaGreaterEqualFunctor, >=)
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaEqualFunctor, ==)
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaNotEqualFunctor, !=)
#undef DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT
template <typename T>
......@@ -67,10 +67,12 @@ class CompareOpKernel<platform::CUDADeviceContext, Functor, InverseFunctor>
auto functor = Functor();
std::vector<const framework::Tensor*> ins;
std::vector<framework::Tensor*> outs;
const auto& cuda_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
PackTensorsIntoVector<OutT>(ctx, &ins, &outs);
int axis = PackTensorsIntoVector<OutT>(ctx, &ins, &outs);
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, InT, OutT>(
ctx, ins, &outs, functor);
cuda_ctx, ins, &outs, axis, functor);
}
};
......@@ -79,19 +81,16 @@ class CompareOpKernel<platform::CUDADeviceContext, Functor, InverseFunctor>
#define REGISTER_CUDA_COMPARE_KERNEL(op_type, func) \
REGISTER_OP_CUDA_KERNEL( \
op_type, ops::CompareOpKernel<plat::CUDADeviceContext, \
ops::func##Functor<int>, void>, \
ops::CompareOpKernel<plat::CUDADeviceContext, \
ops::func##Functor<int64_t>, void>, \
ops::CompareOpKernel<plat::CUDADeviceContext, ops::func##Functor<float>, \
void>, \
ops::CompareOpKernel<plat::CUDADeviceContext, \
ops::func##Functor<double>, void>);
op_type, \
ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<int>, void>, \
ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<int64_t>, void>, \
ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<float>, void>, \
ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<double>, void>);
REGISTER_CUDA_COMPARE_KERNEL(equal, CudaEqual)
REGISTER_CUDA_COMPARE_KERNEL(not_equal, CudaNotEqual)
REGISTER_CUDA_COMPARE_KERNEL(less_than, CudaLessThan)
REGISTER_CUDA_COMPARE_KERNEL(less_equal, CudaLessEqual)
REGISTER_CUDA_COMPARE_KERNEL(greater_than, CudaGreaterThan)
REGISTER_CUDA_COMPARE_KERNEL(greater_equal, CudaGreaterEqual)
REGISTER_CUDA_COMPARE_KERNEL(equal, CudaEqualFunctor)
REGISTER_CUDA_COMPARE_KERNEL(not_equal, CudaNotEqualFunctor)
REGISTER_CUDA_COMPARE_KERNEL(less_than, CudaLessThanFunctor)
REGISTER_CUDA_COMPARE_KERNEL(less_equal, CudaLessEqualFunctor)
REGISTER_CUDA_COMPARE_KERNEL(greater_than, CudaGreaterThanFunctor)
REGISTER_CUDA_COMPARE_KERNEL(greater_equal, CudaGreaterEqualFunctor)
#undef REGISTER_CUDA_COMPARE_KERNEL
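For reference, the refactor above only renames the comparison functors: the macro now receives the full functor name rather than appending a ##Functor suffix itself. Expanding the macro for the less-than case yields a plain pointer-input functor (a sketch of the expansion, taken directly from the macro body in this hunk):
template <typename T, typename Enable = void>
struct CudaLessThanFunctor {
  using ELEMENT_TYPE = T;
  // Pointer-input convention: args[0] holds the element from X, args[1] from Y.
  inline HOSTDEVICE bool operator()(const T* args) const {
    return args[0] < args[1];
  }
};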
......@@ -28,11 +28,11 @@ namespace operators {
1. For Unary Op, the length of input array is 1,
e.g. Relu: return args[0] > 0 ? args[0] : 0;
2. For Binary Op, the length of input array is 2,
e.g. Add: return args[0] + args[1];
e.g. Add: return args[0] expr args[1];
*/
template <typename T>
struct CudaAddFunctor {
__device__ __forceinline__ T operator()(const T* args) const {
inline HOSTDEVICE T operator()(const T* args) const {
return args[0] + args[1];
}
};
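Any further binary functor that follows this pointer-input convention can be plugged into the same launch path. A hypothetical example, not part of this commit (the division functor is invented purely for illustration):
template <typename T>
struct CudaDivFunctor {
  // Binary op: args has length 2; args[0] comes from X, args[1] from Y.
  inline HOSTDEVICE T operator()(const T* args) const {
    return args[0] / args[1];
  }
};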
......@@ -44,9 +44,12 @@ class ElementwiseAddKernel<platform::CUDADeviceContext, T>
void Compute(const framework::ExecutionContext& ctx) const override {
std::vector<const framework::Tensor*> ins;
std::vector<framework::Tensor*> outs;
PackTensorsIntoVector<T>(ctx, &ins, &outs);
const auto& cuda_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
ctx, ins, &outs, CudaAddFunctor<T>());
cuda_ctx, ins, &outs, axis, CudaAddFunctor<T>());
}
};
......
......@@ -72,12 +72,10 @@ class ElementwiseAddKernel : public framework::OpKernel<T> {
auto *z = ctx.Output<framework::LoDTensor>("Out");
z->mutable_data<T>(ctx.GetPlace());
if (x->dims() == y->dims()) {
SameDimsElemwiseAdd<platform::CPUDeviceContext, T>
LaunchElementwiseCpuKernel;
SameDimsElemwiseAdd<DeviceContext, T> LaunchElementwiseCpuKernel;
LaunchElementwiseCpuKernel(ctx, x, y, z);
} else {
LaunchBroadcastElementwiseCpuKernel<platform::CPUDeviceContext, T>(ctx, x,
y, z);
LaunchBroadcastElementwiseCpuKernel<DeviceContext, T>(ctx, x, y, z);
}
}
};
......
......@@ -12,9 +12,40 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_max_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
namespace ops = paddle::operators;
namespace paddle {
namespace operators {
template <typename T>
struct CudaMaxFunctor {
inline HOSTDEVICE T operator()(const T* args) const {
return (args[0] > args[1] ? args[0] : args[1]);
}
};
template <typename T>
class ElementwiseMaxKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
std::vector<const framework::Tensor*> ins;
std::vector<framework::Tensor*> outs;
const auto& cuda_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
cuda_ctx, ins, &outs, axis, CudaMaxFunctor<T>());
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP_CUDA_KERNEL(
elementwise_max,
ops::ElementwiseMaxKernel<paddle::platform::CUDADeviceContext, float>,
......
......@@ -12,9 +12,40 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_min_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
namespace ops = paddle::operators;
namespace paddle {
namespace operators {
template <typename T>
struct CudaMinFunctor {
inline HOSTDEVICE T operator()(const T* args) const {
return (args[0] > args[1] ? args[1] : args[0]);
}
};
template <typename T>
class ElementwiseMinKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
std::vector<const framework::Tensor*> ins;
std::vector<framework::Tensor*> outs;
const auto& cuda_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
cuda_ctx, ins, &outs, axis, CudaMinFunctor<T>());
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP_CUDA_KERNEL(
elementwise_min,
ops::ElementwiseMinKernel<paddle::platform::CUDADeviceContext, float>,
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"
......@@ -24,37 +25,65 @@ namespace paddle {
namespace operators {
template <typename T>
struct SameDimsElemwiseMul<platform::CUDADeviceContext, T> {
void operator()(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
framework::Tensor* z) {
MulRangeFunctor<T> functor(x->data<T>(), y->data<T>(), z->data<T>());
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
platform::ForRange<platform::CUDADeviceContext> for_range(dev_ctx,
x->numel());
for_range(functor);
struct CudaMulFunctor {
inline HOSTDEVICE T operator()(const T* args) const {
return args[0] * args[1];
}
};
template <>
struct SameDimsElemwiseMul<platform::CUDADeviceContext, platform::float16> {
void operator()(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
framework::Tensor* z) {
auto size = x->numel();
dim3 grid_size = dim3(((size + 7) / 8 + PADDLE_CUDA_THREAD_SIZE - 1) /
PADDLE_CUDA_THREAD_SIZE,
1);
dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
const half* x2 =
reinterpret_cast<const half*>(x->data<platform::float16>());
const half* y2 =
reinterpret_cast<const half*>(y->data<platform::float16>());
half* z2 = reinterpret_cast<half*>(z->data<platform::float16>());
SameDimsElemwiseMulCUDAKernel<<<
grid_size, block_size, 0,
ctx.template device_context<platform::CUDADeviceContext>().stream()>>>(
x2, y2, z2, size);
template <typename T>
class ElementwiseMulKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
int axis = -1;
auto x_var = ctx.InputVar("X");
PADDLE_ENFORCE_NOT_NULL(
x_var, platform::errors::InvalidArgument(
"Cannot get input Variable X, Variable name = %s.",
ctx.InputName("X")));
auto* y = ctx.Input<framework::LoDTensor>("Y");
framework::Tensor x, *z;
std::vector<const framework::Tensor*> ins;
std::vector<framework::Tensor*> outs;
const auto& cuda_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
if (x_var->IsType<framework::LoDTensor>()) {
x = x_var->Get<framework::LoDTensor>();
z = ctx.Output<framework::LoDTensor>("Out");
axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
} else if (x_var->IsType<framework::SelectedRows>()) {
PADDLE_ENFORCE_EQ(y->dims().size() == 1 && y->dims()[0] == 1, true,
platform::errors::InvalidArgument(
"For elementwise_op, if X is Sparse, Y must be "
"scalar. But reveived the size of Y = %s.",
y->dims().size()));
auto& x_sele = x_var->Get<framework::SelectedRows>();
auto out_sele = ctx.Output<framework::SelectedRows>("Out");
x = x_sele.value();
out_sele->set_rows(x_sele.rows());
out_sele->set_height(x_sele.height());
out_sele->mutable_value()->Resize(x_sele.value().dims());
out_sele->mutable_value()->mutable_data(ctx.GetPlace(), x.type());
z = ctx.Output<framework::SelectedRows>("Out")->mutable_value();
z->mutable_data<T>(ctx.GetPlace());
outs.emplace_back(z);
ins.emplace_back(&x);
ins.emplace_back(y);
axis = ctx.HasAttr("axis") ? ctx.Attr<int>("axis") : -1;
axis = axis == -1 ? std::abs(y->dims().size() - x.dims().size()) : axis;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"X's type[%s] is not supported by elementwise_op. X's type should be "
"LoDTensor or SelectedRows.",
framework::ToTypeName(x_var->Type())));
}
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
cuda_ctx, ins, &outs, axis, CudaMulFunctor<T>());
}
};
......
......@@ -126,7 +126,6 @@ class ElementwiseMulKernel : public framework::OpKernel<T> {
}
}
};
template <typename T>
struct MulGradDX {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * y; }
......
......@@ -465,7 +465,11 @@ void LaunchBroadcastElementwiseCudaKernel(
const platform::CUDADeviceContext &ctx,
const std::vector<const framework::Tensor *> &ins,
std::vector<framework::Tensor *> *outs, int axis, Functor func) {
static_assert(ET == (ElementwiseType)2, "Only Support binary calculation.");
PADDLE_ENFORCE_EQ(ET, ElementwiseType::kBinary,
platform::errors::InvalidArgument(
"Currently, only Support binary calculation, "
"but received %d input tensors.\n",
static_cast<int>(ET)));
int in_vec_size = 4;
framework::Tensor *out = (*outs)[0];
for (auto *in : ins) {
......@@ -502,26 +506,18 @@ void LaunchBroadcastElementwiseCudaKernel(
template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
void LaunchElementwiseCudaKernel(
const framework::ExecutionContext &ctx,
const platform::CUDADeviceContext &cuda_ctx,
const std::vector<const framework::Tensor *> &ins,
std::vector<framework::Tensor *> *outs, Functor func) {
std::vector<int> dims_size;
std::vector<framework::Tensor *> *outs, int axis, Functor func) {
bool no_broadcast_flag = true;
for (auto *in : ins) {
no_broadcast_flag = ins[0]->dims() == in->dims();
dims_size.emplace_back(in->dims().size());
}
const auto &cuda_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
if (no_broadcast_flag) {
LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kBinary, InT, OutT>(
cuda_ctx, ins, outs, func);
LaunchSameDimsElementwiseCudaKernel<ET, InT, OutT>(cuda_ctx, ins, outs,
func);
} else {
int axis = ctx.HasAttr("axis") ? ctx.Attr<int>("axis") : -1;
axis = axis == -1
? *std::max_element(dims_size.begin(), dims_size.end()) -
*std::min_element(dims_size.begin(), dims_size.end())
: axis;
LaunchBroadcastElementwiseCudaKernel<ET, InT, OutT>(cuda_ctx, ins, outs,
axis, func);
}
......
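After this change the launcher takes the CUDA device context and the broadcast axis explicitly instead of the whole ExecutionContext. For a new binary op, the kernel body reduces to the pattern repeated throughout this diff; a sketch of the code inside Compute(), where SomeBinaryFunctor stands in for any functor with the pointer-input interface:
std::vector<const framework::Tensor*> ins;
std::vector<framework::Tensor*> outs;
const auto& cuda_ctx =
    ctx.template device_context<platform::CUDADeviceContext>();
int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
    cuda_ctx, ins, &outs, axis, SomeBinaryFunctor<T>());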
......@@ -64,20 +64,24 @@ namespace operators {
* To pack the input and output tensors into vectors for
* LaunchElementwiseCudaKernel
*/
template <typename T>
void PackTensorsIntoVector(const framework::ExecutionContext &ctx,
std::vector<const framework::Tensor *> *ins,
std::vector<framework::Tensor *> *outs) {
template <typename OutT>
int PackTensorsIntoVector(const framework::ExecutionContext &ctx,
std::vector<const framework::Tensor *> *ins,
std::vector<framework::Tensor *> *outs) {
int axis = -1;
auto *x = ctx.Input<framework::LoDTensor>("X");
auto *y = ctx.Input<framework::LoDTensor>("Y");
auto *z = ctx.Output<framework::LoDTensor>("Out");
z->mutable_data<T>(ctx.GetPlace());
ins->emplace_back(x);
z->mutable_data<OutT>(ctx.GetPlace());
outs->emplace_back(z);
ins->emplace_back(x);
if (y != nullptr) {
ins->emplace_back(y);
axis = ctx.HasAttr("axis") ? ctx.Attr<int>("axis") : -1;
axis = axis == -1 ? std::abs(y->dims().size() - x->dims().size()) : axis;
}
return axis;
}
/*
......
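PackTensorsIntoVector now also resolves the default broadcast axis: when the op carries no axis attribute (or it is -1), the axis falls back to the rank difference between X and Y, which aligns Y with the trailing dimensions of X. A small illustration of that rule (default_axis is a hypothetical helper named only for this example):
#include <cstdlib>

// Mirrors: axis = axis == -1 ? std::abs(y->dims().size() - x->dims().size()) : axis;
int default_axis(int x_rank, int y_rank) { return std::abs(y_rank - x_rank); }

// e.g. X with shape [2, 3, 4, 5] (rank 4) and Y with shape [4, 5] (rank 2):
// default_axis(4, 2) == 2, so Y is broadcast against dimensions 2 and 3 of X.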
......@@ -8,10 +8,52 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_pow_op.h"
namespace ops = paddle::operators;
namespace paddle {
namespace operators {
template <typename T, typename Enable = void>
struct CudaPowFunctor {
inline HOSTDEVICE T operator()(const T args[]) const {
return std::pow(args[0], args[1]);
}
};
template <typename T>
struct CudaPowFunctor<
T, typename std::enable_if<std::is_integral<T>::value>::type> {
// On CUDAPlace, std::pow(3, 1) calls pow(float, float), and
// it will return a float number like 2.99... , which floor to 2
// when cast to int by default and it is wrong.
// Use llrint to cast it to the nearest integer, which is 3.
inline HOSTDEVICE T operator()(const T args[]) const {
return std::llrint(std::pow(args[0], args[1]));
}
};
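To make the comment above concrete: if the device pow returns a value just below the exact integer result, a plain integral cast truncates it while llrint rounds to the nearest integer. The numeric value below is an assumed example, not a measured device result:
#include <cmath>

double approx = 2.9999995;                 // what pow(3.0f, 1.0f) might yield on the device
int truncated = static_cast<int>(approx);  // 2 -- the wrong result the comment warns about
long long rounded = std::llrint(approx);   // 3 -- what the integral specialization returns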
template <typename T>
class ElementwisePowKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
std::vector<const framework::Tensor*> ins;
std::vector<framework::Tensor*> outs;
const auto& cuda_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
cuda_ctx, ins, &outs, axis, CudaPowFunctor<T>());
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP_CUDA_KERNEL(
elementwise_pow,
ops::ElementwisePowKernel<paddle::platform::CUDADeviceContext, float>,
......
......@@ -11,8 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_sub_op.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"
......@@ -24,37 +23,25 @@ namespace paddle {
namespace operators {
template <typename T>
struct SameDimsElemwiseSub<platform::CUDADeviceContext, T> {
void operator()(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
framework::Tensor* z) {
SubRangeFunctor<T> functor(x->data<T>(), y->data<T>(), z->data<T>());
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
platform::ForRange<platform::CUDADeviceContext> for_range(dev_ctx,
x->numel());
for_range(functor);
struct CudaSubFunctor {
inline HOSTDEVICE T operator()(const T* args) const {
return args[0] - args[1];
}
};
template <>
struct SameDimsElemwiseSub<platform::CUDADeviceContext, platform::float16> {
void operator()(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
framework::Tensor* z) {
auto size = x->numel();
dim3 grid_size = dim3(((size + 7) / 8 + PADDLE_CUDA_THREAD_SIZE - 1) /
PADDLE_CUDA_THREAD_SIZE,
1);
dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
const half* x2 =
reinterpret_cast<const half*>(x->data<platform::float16>());
const half* y2 =
reinterpret_cast<const half*>(y->data<platform::float16>());
half* z2 = reinterpret_cast<half*>(z->data<platform::float16>());
SameDimsElemwiseSubCUDAKernel<<<
grid_size, block_size, 0,
ctx.template device_context<platform::CUDADeviceContext>().stream()>>>(
x2, y2, z2, size);
template <typename T>
class ElementwiseSubKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
std::vector<const framework::Tensor*> ins;
std::vector<framework::Tensor*> outs;
const auto& cuda_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
cuda_ctx, ins, &outs, axis, CudaSubFunctor<T>());
}
};
......
......@@ -11,8 +11,8 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
......