未验证 提交 54a4696f 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #7660 from reyoung/feature/compare_op_use_elemwise

Make compare_op reuse elemwise_op_funcs
......@@ -39,6 +39,11 @@ N-dim tensor. X and Y could be any type. The each element of the Out tensor is
calculated by %s
)DOC",
comment.type, comment.equation));
AddAttr<int>("axis",
"(int, default -1). The start dimension index "
"for broadcasting Y onto X.")
.SetDefault(-1)
.EqualGreaterThan(-1);
}
};
......@@ -95,11 +100,5 @@ REGISTER_LOGICAL_OP(less_than, "Out = X < Y");
REGISTER_LOGICAL_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor);
REGISTER_LOGICAL_OP(less_equal, "Out = X <= Y");
REGISTER_LOGICAL_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor);
REGISTER_LOGICAL_OP(greater_than, "Out = X > Y");
REGISTER_LOGICAL_KERNEL(greater_than, CPU,
paddle::operators::GreaterThanFunctor);
REGISTER_LOGICAL_OP(greater_equal, "Out = X >= Y");
REGISTER_LOGICAL_KERNEL(greater_equal, CPU,
paddle::operators::GreaterEqualFunctor);
REGISTER_LOGICAL_OP(equal, "Out = X == Y");
REGISTER_LOGICAL_KERNEL(equal, CPU, paddle::operators::EqualFunctor);
......@@ -16,8 +16,4 @@ limitations under the License. */
REGISTER_LOGICAL_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor);
REGISTER_LOGICAL_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor);
REGISTER_LOGICAL_KERNEL(greater_than, CUDA,
paddle::operators::GreaterThanFunctor);
REGISTER_LOGICAL_KERNEL(greater_equal, CUDA,
paddle::operators::GreaterEqualFunctor);
REGISTER_LOGICAL_KERNEL(equal, CUDA, paddle::operators::EqualFunctor);
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <math.h>
#include <type_traits>
#include "paddle/framework/op_registry.h"
#include "paddle/operators/elementwise_op_function.h"
#include "paddle/platform/transform.h"
namespace paddle {
......@@ -33,18 +34,6 @@ struct LessEqualFunctor {
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a <= b; }
};
template <typename T>
struct GreaterThanFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a > b; }
};
template <typename T>
struct GreaterEqualFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a >= b; }
};
template <typename T>
struct EqualFunctor {
using ELEM_TYPE = T;
......@@ -65,14 +54,7 @@ class CompareOpKernel
public:
void Compute(const framework::ExecutionContext& context) const override {
using T = typename Functor::ELEM_TYPE;
auto* x = context.Input<framework::Tensor>("X");
auto* y = context.Input<framework::Tensor>("Y");
auto* out = context.Output<framework::Tensor>("Out");
Functor binary_func;
platform::Transform<DeviceContext> trans;
trans(context.template device_context<DeviceContext>(), x->data<T>(),
x->data<T>() + x->numel(), y->data<T>(),
out->mutable_data<bool>(context.GetPlace()), binary_func);
ElementwiseComputeEx<Functor, DeviceContext, T, bool>(context);
}
};
......
......@@ -176,14 +176,15 @@ class MidWiseTransformIterator<T, platform::CUDADeviceContext>
};
#endif
template <typename Functor, typename T, typename DeviceContext>
template <typename Functor, typename T, typename DeviceContext,
typename OutType = T>
class TransformFunctor {
public:
TransformFunctor(const framework::Tensor* x, const framework::Tensor* y,
framework::Tensor* z, const DeviceContext& ctx, Functor func)
: x_(x->data<T>()),
y_(y->data<T>()),
z_(z->mutable_data<T>(ctx.GetPlace())),
z_(z->mutable_data<OutType>(ctx.GetPlace())),
nx_(x->numel()),
ctx_(ctx),
func_(func) {}
......@@ -208,7 +209,7 @@ class TransformFunctor {
private:
const T* x_;
const T* y_;
T* z_;
OutType* z_;
int64_t nx_;
const DeviceContext& ctx_;
Functor func_;
......@@ -364,15 +365,16 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx) {
}
}
template <typename Functor, typename DeviceContext, typename T>
template <typename Functor, typename DeviceContext, typename T,
typename OutType = T>
void ElementwiseComputeEx(const framework::ExecutionContext& ctx) {
using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* z = ctx.Output<Tensor>("Out");
z->mutable_data<T>(ctx.GetPlace());
TransformFunctor<Functor, T, DeviceContext> functor(
z->mutable_data<OutType>(ctx.GetPlace());
TransformFunctor<Functor, T, DeviceContext, OutType> functor(
x, y, z, ctx.template device_context<DeviceContext>(), Functor());
auto x_dims = x->dims();
......
......@@ -38,8 +38,6 @@ def create_test_class(op_type, typename, callback):
for _type_name in {'float32', 'float64', 'int32', 'int64'}:
create_test_class('less_than', _type_name, lambda _a, _b: _a < _b)
create_test_class('less_equal', _type_name, lambda _a, _b: _a <= _b)
create_test_class('greater_than', _type_name, lambda _a, _b: _a > _b)
create_test_class('greater_equal', _type_name, lambda _a, _b: _a >= _b)
create_test_class('equal', _type_name, lambda _a, _b: _a == _b)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册