未验证 提交 54a4696f 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #7660 from reyoung/feature/compare_op_use_elemwise

Make compare_op reuse elemwise_op_funcs
...@@ -39,6 +39,11 @@ N-dim tensor. X and Y could be any type. The each element of the Out tensor is ...@@ -39,6 +39,11 @@ N-dim tensor. X and Y could be any type. The each element of the Out tensor is
calculated by %s calculated by %s
)DOC", )DOC",
comment.type, comment.equation)); comment.type, comment.equation));
AddAttr<int>("axis",
"(int, default -1). The start dimension index "
"for broadcasting Y onto X.")
.SetDefault(-1)
.EqualGreaterThan(-1);
} }
}; };
...@@ -95,11 +100,5 @@ REGISTER_LOGICAL_OP(less_than, "Out = X < Y"); ...@@ -95,11 +100,5 @@ REGISTER_LOGICAL_OP(less_than, "Out = X < Y");
REGISTER_LOGICAL_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor); REGISTER_LOGICAL_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor);
REGISTER_LOGICAL_OP(less_equal, "Out = X <= Y"); REGISTER_LOGICAL_OP(less_equal, "Out = X <= Y");
REGISTER_LOGICAL_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor); REGISTER_LOGICAL_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor);
REGISTER_LOGICAL_OP(greater_than, "Out = X > Y");
REGISTER_LOGICAL_KERNEL(greater_than, CPU,
paddle::operators::GreaterThanFunctor);
REGISTER_LOGICAL_OP(greater_equal, "Out = X >= Y");
REGISTER_LOGICAL_KERNEL(greater_equal, CPU,
paddle::operators::GreaterEqualFunctor);
REGISTER_LOGICAL_OP(equal, "Out = X == Y"); REGISTER_LOGICAL_OP(equal, "Out = X == Y");
REGISTER_LOGICAL_KERNEL(equal, CPU, paddle::operators::EqualFunctor); REGISTER_LOGICAL_KERNEL(equal, CPU, paddle::operators::EqualFunctor);
...@@ -16,8 +16,4 @@ limitations under the License. */ ...@@ -16,8 +16,4 @@ limitations under the License. */
REGISTER_LOGICAL_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor); REGISTER_LOGICAL_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor);
REGISTER_LOGICAL_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor); REGISTER_LOGICAL_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor);
REGISTER_LOGICAL_KERNEL(greater_than, CUDA,
paddle::operators::GreaterThanFunctor);
REGISTER_LOGICAL_KERNEL(greater_equal, CUDA,
paddle::operators::GreaterEqualFunctor);
REGISTER_LOGICAL_KERNEL(equal, CUDA, paddle::operators::EqualFunctor); REGISTER_LOGICAL_KERNEL(equal, CUDA, paddle::operators::EqualFunctor);
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#include <math.h> #include <math.h>
#include <type_traits> #include <type_traits>
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/operators/elementwise_op_function.h"
#include "paddle/platform/transform.h" #include "paddle/platform/transform.h"
namespace paddle { namespace paddle {
...@@ -33,18 +34,6 @@ struct LessEqualFunctor { ...@@ -33,18 +34,6 @@ struct LessEqualFunctor {
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a <= b; } HOSTDEVICE bool operator()(const T& a, const T& b) const { return a <= b; }
}; };
template <typename T>
struct GreaterThanFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a > b; }
};
template <typename T>
struct GreaterEqualFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a >= b; }
};
template <typename T> template <typename T>
struct EqualFunctor { struct EqualFunctor {
using ELEM_TYPE = T; using ELEM_TYPE = T;
...@@ -65,14 +54,7 @@ class CompareOpKernel ...@@ -65,14 +54,7 @@ class CompareOpKernel
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
using T = typename Functor::ELEM_TYPE; using T = typename Functor::ELEM_TYPE;
auto* x = context.Input<framework::Tensor>("X"); ElementwiseComputeEx<Functor, DeviceContext, T, bool>(context);
auto* y = context.Input<framework::Tensor>("Y");
auto* out = context.Output<framework::Tensor>("Out");
Functor binary_func;
platform::Transform<DeviceContext> trans;
trans(context.template device_context<DeviceContext>(), x->data<T>(),
x->data<T>() + x->numel(), y->data<T>(),
out->mutable_data<bool>(context.GetPlace()), binary_func);
} }
}; };
......
...@@ -176,14 +176,15 @@ class MidWiseTransformIterator<T, platform::CUDADeviceContext> ...@@ -176,14 +176,15 @@ class MidWiseTransformIterator<T, platform::CUDADeviceContext>
}; };
#endif #endif
template <typename Functor, typename T, typename DeviceContext> template <typename Functor, typename T, typename DeviceContext,
typename OutType = T>
class TransformFunctor { class TransformFunctor {
public: public:
TransformFunctor(const framework::Tensor* x, const framework::Tensor* y, TransformFunctor(const framework::Tensor* x, const framework::Tensor* y,
framework::Tensor* z, const DeviceContext& ctx, Functor func) framework::Tensor* z, const DeviceContext& ctx, Functor func)
: x_(x->data<T>()), : x_(x->data<T>()),
y_(y->data<T>()), y_(y->data<T>()),
z_(z->mutable_data<T>(ctx.GetPlace())), z_(z->mutable_data<OutType>(ctx.GetPlace())),
nx_(x->numel()), nx_(x->numel()),
ctx_(ctx), ctx_(ctx),
func_(func) {} func_(func) {}
...@@ -208,7 +209,7 @@ class TransformFunctor { ...@@ -208,7 +209,7 @@ class TransformFunctor {
private: private:
const T* x_; const T* x_;
const T* y_; const T* y_;
T* z_; OutType* z_;
int64_t nx_; int64_t nx_;
const DeviceContext& ctx_; const DeviceContext& ctx_;
Functor func_; Functor func_;
...@@ -364,15 +365,16 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx) { ...@@ -364,15 +365,16 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx) {
} }
} }
template <typename Functor, typename DeviceContext, typename T> template <typename Functor, typename DeviceContext, typename T,
typename OutType = T>
void ElementwiseComputeEx(const framework::ExecutionContext& ctx) { void ElementwiseComputeEx(const framework::ExecutionContext& ctx) {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X"); auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y"); auto* y = ctx.Input<Tensor>("Y");
auto* z = ctx.Output<Tensor>("Out"); auto* z = ctx.Output<Tensor>("Out");
z->mutable_data<T>(ctx.GetPlace()); z->mutable_data<OutType>(ctx.GetPlace());
TransformFunctor<Functor, T, DeviceContext> functor( TransformFunctor<Functor, T, DeviceContext, OutType> functor(
x, y, z, ctx.template device_context<DeviceContext>(), Functor()); x, y, z, ctx.template device_context<DeviceContext>(), Functor());
auto x_dims = x->dims(); auto x_dims = x->dims();
......
...@@ -38,8 +38,6 @@ def create_test_class(op_type, typename, callback): ...@@ -38,8 +38,6 @@ def create_test_class(op_type, typename, callback):
for _type_name in {'float32', 'float64', 'int32', 'int64'}: for _type_name in {'float32', 'float64', 'int32', 'int64'}:
create_test_class('less_than', _type_name, lambda _a, _b: _a < _b) create_test_class('less_than', _type_name, lambda _a, _b: _a < _b)
create_test_class('less_equal', _type_name, lambda _a, _b: _a <= _b) create_test_class('less_equal', _type_name, lambda _a, _b: _a <= _b)
create_test_class('greater_than', _type_name, lambda _a, _b: _a > _b)
create_test_class('greater_equal', _type_name, lambda _a, _b: _a >= _b)
create_test_class('equal', _type_name, lambda _a, _b: _a == _b) create_test_class('equal', _type_name, lambda _a, _b: _a == _b)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册