未验证 提交 6bfb8152 编写于 作者: H HongyuJia 提交者: GitHub

[Tensor Operants & Prim-Relevant] Tensor supports compare operants (#51713)

* [Tensor Operants & Prim-Relevant] Tensor supports compare operants

* fix dependence of test_comp_static

* fix unit test
上级 af95a8b4
...@@ -2,6 +2,12 @@ ...@@ -2,6 +2,12 @@
- subtract - subtract
- multiply - multiply
- divide - divide
- less_equal
- less_than
- equal
- not_equal
- greater_equal
- greater_than
- bitwise_and - bitwise_and
- bitwise_not - bitwise_not
- bitwise_or - bitwise_or
......
...@@ -38,7 +38,8 @@ cc_test_old( ...@@ -38,7 +38,8 @@ cc_test_old(
static_global_utils static_global_utils
static_tensor_operants static_tensor_operants
tensor_api tensor_api
operants_manager) operants_manager
generated_static_op)
if(NOT (NOT WITH_PYTHON AND ON_INFER)) if(NOT (NOT WITH_PYTHON AND ON_INFER))
cc_library( cc_library(
......
...@@ -35,6 +35,12 @@ PD_DECLARE_KERNEL(tanh_grad, CPU, ALL_LAYOUT); ...@@ -35,6 +35,12 @@ PD_DECLARE_KERNEL(tanh_grad, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(pow, CPU, ALL_LAYOUT); PD_DECLARE_KERNEL(pow, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(scale, CPU, ALL_LAYOUT); PD_DECLARE_KERNEL(scale, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(multiply, CPU, ALL_LAYOUT); PD_DECLARE_KERNEL(multiply, CPU, ALL_LAYOUT);
// Comparison kernels (CPU) required by the CompareOperantsTest below.
PD_DECLARE_KERNEL(less_equal, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(less_than, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(equal, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(not_equal, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(greater_equal, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(greater_than, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(bitwise_and, CPU, ALL_LAYOUT); PD_DECLARE_KERNEL(bitwise_and, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(bitwise_or, CPU, ALL_LAYOUT); PD_DECLARE_KERNEL(bitwise_or, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(bitwise_xor, CPU, ALL_LAYOUT); PD_DECLARE_KERNEL(bitwise_xor, CPU, ALL_LAYOUT);
...@@ -46,6 +52,12 @@ PD_DECLARE_KERNEL(tanh_grad, GPU, ALL_LAYOUT); ...@@ -46,6 +52,12 @@ PD_DECLARE_KERNEL(tanh_grad, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(pow, GPU, ALL_LAYOUT); PD_DECLARE_KERNEL(pow, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(scale, GPU, ALL_LAYOUT); PD_DECLARE_KERNEL(scale, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(multiply, KPS, ALL_LAYOUT); PD_DECLARE_KERNEL(multiply, KPS, ALL_LAYOUT);
// Comparison kernels (KPS/GPU build) required by the CompareOperantsTest below.
PD_DECLARE_KERNEL(less_equal, KPS, ALL_LAYOUT);
PD_DECLARE_KERNEL(less_than, KPS, ALL_LAYOUT);
PD_DECLARE_KERNEL(equal, KPS, ALL_LAYOUT);
PD_DECLARE_KERNEL(not_equal, KPS, ALL_LAYOUT);
PD_DECLARE_KERNEL(greater_equal, KPS, ALL_LAYOUT);
PD_DECLARE_KERNEL(greater_than, KPS, ALL_LAYOUT);
PD_DECLARE_KERNEL(bitwise_and, KPS, ALL_LAYOUT); PD_DECLARE_KERNEL(bitwise_and, KPS, ALL_LAYOUT);
PD_DECLARE_KERNEL(bitwise_or, KPS, ALL_LAYOUT); PD_DECLARE_KERNEL(bitwise_or, KPS, ALL_LAYOUT);
PD_DECLARE_KERNEL(bitwise_xor, KPS, ALL_LAYOUT); PD_DECLARE_KERNEL(bitwise_xor, KPS, ALL_LAYOUT);
...@@ -151,6 +163,50 @@ TEST(EagerPrim, LogicalOperantsTest) { ...@@ -151,6 +163,50 @@ TEST(EagerPrim, LogicalOperantsTest) {
EXPECT_EQ(out0.data<int>()[0], out1.data<int>()[0]); EXPECT_EQ(out0.data<int>()[0], out1.data<int>()[0]);
} }
// Verifies that each overloaded comparison operator on paddle::Tensor
// produces the same result as the corresponding *_ad_func in eager mode.
TEST(EagerPrim, CompareOperantsTest) {
  // Environment: eager mode with eager tensor operants installed.
  eager_test::InitEnv(paddle::platform::CPUPlace());
  FLAGS_tensor_operants_mode = "eager";
  paddle::prim::InitTensorOperants();

  // Build two int32 leaf tensors: one filled with 1, one filled with 0.
  paddle::framework::DDim dims = phi::make_ddim({4, 16, 16, 32});
  auto make_leaf = [&dims](int value) {
    paddle::Tensor t =
        ::egr::egr_utils_api::CreateTensorWithValue(dims,
                                                    paddle::platform::CPUPlace(),
                                                    phi::DataType::INT32,
                                                    phi::DataLayout::NCHW,
                                                    value,
                                                    /*is_leaf=*/true);
    ::egr::egr_utils_api::RetainGradForTensor(t);
    return t;
  };
  paddle::Tensor lhs = make_leaf(1);
  paddle::Tensor rhs = make_leaf(0);

  // Each operator result must match the ad_func result element-wise
  // (spot-checked on the first element, as in the sibling tests).
  auto expect_same = [](const paddle::Tensor& a, const paddle::Tensor& b) {
    EXPECT_EQ(a.data<bool>()[0], b.data<bool>()[0]);
  };
  expect_same(lhs < rhs, less_than_ad_func(lhs, rhs));
  expect_same(lhs <= rhs, less_equal_ad_func(lhs, rhs));
  expect_same(lhs == rhs, equal_ad_func(lhs, rhs));
  expect_same(lhs != rhs, not_equal_ad_func(lhs, rhs));
  expect_same(lhs > rhs, greater_than_ad_func(lhs, rhs));
  expect_same(lhs >= rhs, greater_equal_ad_func(lhs, rhs));
}
TEST(EagerPrim, TestFlags) { TEST(EagerPrim, TestFlags) {
PrimCommonUtils::SetBwdPrimEnabled(true); PrimCommonUtils::SetBwdPrimEnabled(true);
ASSERT_TRUE(PrimCommonUtils::IsBwdPrimEnabled()); ASSERT_TRUE(PrimCommonUtils::IsBwdPrimEnabled());
......
...@@ -38,6 +38,12 @@ PD_DECLARE_KERNEL(scale, CPU, ALL_LAYOUT); ...@@ -38,6 +38,12 @@ PD_DECLARE_KERNEL(scale, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(subtract, CPU, ALL_LAYOUT); PD_DECLARE_KERNEL(subtract, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(multiply, CPU, ALL_LAYOUT); PD_DECLARE_KERNEL(multiply, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(concat, CPU, ALL_LAYOUT); PD_DECLARE_KERNEL(concat, CPU, ALL_LAYOUT);
// Comparison kernels (CPU) required by the static CompareOperantsTest below.
PD_DECLARE_KERNEL(less_equal, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(less_than, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(equal, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(not_equal, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(greater_equal, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(greater_than, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(bitwise_and, CPU, ALL_LAYOUT); PD_DECLARE_KERNEL(bitwise_and, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(bitwise_or, CPU, ALL_LAYOUT); PD_DECLARE_KERNEL(bitwise_or, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(bitwise_xor, CPU, ALL_LAYOUT); PD_DECLARE_KERNEL(bitwise_xor, CPU, ALL_LAYOUT);
...@@ -51,6 +57,12 @@ PD_DECLARE_KERNEL(scale, GPU, ALL_LAYOUT); ...@@ -51,6 +57,12 @@ PD_DECLARE_KERNEL(scale, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(subtract, KPS, ALL_LAYOUT); PD_DECLARE_KERNEL(subtract, KPS, ALL_LAYOUT);
PD_DECLARE_KERNEL(multiply, KPS, ALL_LAYOUT); PD_DECLARE_KERNEL(multiply, KPS, ALL_LAYOUT);
PD_DECLARE_KERNEL(concat, GPU, ALL_LAYOUT); PD_DECLARE_KERNEL(concat, GPU, ALL_LAYOUT);
// Comparison kernels (KPS/GPU build) required by the static CompareOperantsTest.
PD_DECLARE_KERNEL(less_equal, KPS, ALL_LAYOUT);
PD_DECLARE_KERNEL(less_than, KPS, ALL_LAYOUT);
PD_DECLARE_KERNEL(equal, KPS, ALL_LAYOUT);
PD_DECLARE_KERNEL(not_equal, KPS, ALL_LAYOUT);
PD_DECLARE_KERNEL(greater_equal, KPS, ALL_LAYOUT);
PD_DECLARE_KERNEL(greater_than, KPS, ALL_LAYOUT);
PD_DECLARE_KERNEL(bitwise_and, KPS, ALL_LAYOUT); PD_DECLARE_KERNEL(bitwise_and, KPS, ALL_LAYOUT);
PD_DECLARE_KERNEL(bitwise_or, KPS, ALL_LAYOUT); PD_DECLARE_KERNEL(bitwise_or, KPS, ALL_LAYOUT);
PD_DECLARE_KERNEL(bitwise_xor, KPS, ALL_LAYOUT); PD_DECLARE_KERNEL(bitwise_xor, KPS, ALL_LAYOUT);
...@@ -429,6 +441,99 @@ TEST(StaticCompositeGradMaker, LogicalOperantsTest) { ...@@ -429,6 +441,99 @@ TEST(StaticCompositeGradMaker, LogicalOperantsTest) {
std::size_t(1)); std::size_t(1));
} }
// Verifies that each overloaded comparison operator, when routed through the
// static tensor operants, appends exactly one op of the expected type to the
// target block, wired to the expected input/output variables.
TEST(StaticCompositeGradMaker, CompareOperantsTest) {
  // Environment: static mode with static tensor operants installed.
  FLAGS_tensor_operants_mode = "static";
  paddle::OperantsManager::Instance().static_operants.reset(
      new paddle::prim::StaticTensorOperants());
  TestBaseProgram base_program = TestBaseProgram();
  auto* target_block = base_program.GetBlock(0);
  StaticCompositeContext::Instance().SetBlock(target_block);

  // Create seven int32 inputs and record the generated variable names.
  std::vector<int64_t> shape = {2, 2};
  std::vector<Tensor> inputs;
  std::vector<std::string> input_names;
  for (int i = 0; i < 7; ++i) {
    Tensor t = prim::empty<prim::DescTensor>(
        shape, phi::DataType::INT32, phi::CPUPlace());
    input_names.push_back(
        std::static_pointer_cast<prim::DescTensor>(t.impl())->Name());
    inputs.push_back(t);
  }

  // Chain all six comparison operators; each appends one op to the block.
  Tensor chained = (inputs[0] < inputs[1]);
  chained = (chained <= inputs[2]);
  chained = (chained == inputs[3]);
  chained = (chained != inputs[4]);
  chained = (chained > inputs[5]);
  chained = (chained >= inputs[6]);

  auto ops = target_block->AllOps();
  ASSERT_EQ(ops.size(), static_cast<std::size_t>(6));

  // Only the first op reads a fresh input through slot "X"; the rest take
  // the previous result as "X", so only "Y" is checked against input names.
  ASSERT_EQ(ops[0]->Inputs().at("X").size(), static_cast<std::size_t>(1));
  ASSERT_EQ(ops[0]->Inputs().at("X")[0], input_names[0]);

  const char* expected_types[] = {"less_than",
                                  "less_equal",
                                  "equal",
                                  "not_equal",
                                  "greater_than",
                                  "greater_equal"};
  for (std::size_t i = 0; i < 6; ++i) {
    ASSERT_EQ(ops[i]->Type(), expected_types[i]);
    ASSERT_EQ(ops[i]->Inputs().at("Y").size(), static_cast<std::size_t>(1));
    ASSERT_EQ(ops[i]->Inputs().at("Y")[0], input_names[i + 1]);
    ASSERT_EQ(ops[i]->Outputs().at("Out").size(), static_cast<std::size_t>(1));
  }
}
TEST(StaticPrim, TestFlags) { TEST(StaticPrim, TestFlags) {
PrimCommonUtils::SetBwdPrimEnabled(true); PrimCommonUtils::SetBwdPrimEnabled(true);
ASSERT_TRUE(PrimCommonUtils::IsBwdPrimEnabled()); ASSERT_TRUE(PrimCommonUtils::IsBwdPrimEnabled());
...@@ -445,6 +550,12 @@ USE_OP_ITSELF(elementwise_mul); ...@@ -445,6 +550,12 @@ USE_OP_ITSELF(elementwise_mul);
USE_OP_ITSELF(elementwise_sub); USE_OP_ITSELF(elementwise_sub);
USE_OP_ITSELF(elementwise_pow); USE_OP_ITSELF(elementwise_pow);
USE_OP_ITSELF(scale); USE_OP_ITSELF(scale);
// Register the comparison operators exercised by CompareOperantsTest.
USE_OP_ITSELF(less_equal);
USE_OP_ITSELF(less_than);
USE_OP_ITSELF(equal);
USE_OP_ITSELF(not_equal);
USE_OP_ITSELF(greater_equal);
USE_OP_ITSELF(greater_than);
USE_OP_ITSELF(bitwise_xor); USE_OP_ITSELF(bitwise_xor);
USE_OP_ITSELF(bitwise_and); USE_OP_ITSELF(bitwise_and);
USE_OP_ITSELF(bitwise_not); USE_OP_ITSELF(bitwise_not);
......
...@@ -534,29 +534,23 @@ class PADDLE_API Tensor final { ...@@ -534,29 +534,23 @@ class PADDLE_API Tensor final {
* @return Tensor * @return Tensor
*/ */
Tensor operator+(const Tensor& other) const; Tensor operator+(const Tensor& other) const;
Tensor operator-(const Tensor& other) const; Tensor operator-(const Tensor& other) const;
Tensor operator*(const Tensor& other) const; Tensor operator*(const Tensor& other) const;
Tensor operator/(const Tensor& other) const; Tensor operator/(const Tensor& other) const;
Tensor operator+(const Scalar& other) const; Tensor operator+(const Scalar& other) const;
Tensor operator-(const Scalar& other) const; Tensor operator-(const Scalar& other) const;
Tensor operator*(const Scalar& other) const; Tensor operator*(const Scalar& other) const;
Tensor operator/(const Scalar& other) const; Tensor operator/(const Scalar& other) const;
/* Element-wise comparison operators. Each forwards to the corresponding
 * named member (less_than, less_equal, equal, not_equal, greater_than,
 * greater_equal) and returns a Tensor of comparison results (boolean
 * per-element, judging by data<bool>() usage in the tests). */
Tensor operator<(const Tensor& other) const;
Tensor operator<=(const Tensor& other) const;
Tensor operator==(const Tensor& other) const;
Tensor operator!=(const Tensor& other) const;
Tensor operator>(const Tensor& other) const;
Tensor operator>=(const Tensor& other) const;
Tensor operator-() const; Tensor operator-() const;
Tensor operator~() const; Tensor operator~() const;
Tensor operator&(const Tensor& other) const; Tensor operator&(const Tensor& other) const;
Tensor operator|(const Tensor& other) const; Tensor operator|(const Tensor& other) const;
Tensor operator^(const Tensor& other) const; Tensor operator^(const Tensor& other) const;
/* Part 8: Autograd methods */ /* Part 8: Autograd methods */
...@@ -678,6 +672,12 @@ class PADDLE_API Tensor final { ...@@ -678,6 +672,12 @@ class PADDLE_API Tensor final {
Tensor divide(const Scalar& y) const; Tensor divide(const Scalar& y) const;
Tensor multiply(const Scalar& y) const; Tensor multiply(const Scalar& y) const;
Tensor subtract(const Scalar& y) const; Tensor subtract(const Scalar& y) const;
/* Named element-wise comparison methods against tensor `y`.
 * NOTE(review): implementations are not visible here — presumably generated
 * and dispatched through OperantsManager like the arithmetic operants above;
 * confirm in the generated tensor operants source. */
Tensor less_equal(const Tensor& y) const;
Tensor less_than(const Tensor& y) const;
Tensor equal(const Tensor& y) const;
Tensor not_equal(const Tensor& y) const;
Tensor greater_equal(const Tensor& y) const;
Tensor greater_than(const Tensor& y) const;
Tensor bitwise_and(const Tensor& y) const; Tensor bitwise_and(const Tensor& y) const;
Tensor bitwise_or(const Tensor& y) const; Tensor bitwise_or(const Tensor& y) const;
Tensor bitwise_xor(const Tensor& y) const; Tensor bitwise_xor(const Tensor& y) const;
......
...@@ -144,6 +144,30 @@ Tensor Tensor::subtract(const Scalar& y) const { ...@@ -144,6 +144,30 @@ Tensor Tensor::subtract(const Scalar& y) const {
return paddle::OperantsManager::Instance().subtract(static_cast<const Tensor &>(*this), y); return paddle::OperantsManager::Instance().subtract(static_cast<const Tensor &>(*this), y);
} }
// Comparison operator overloads: each is a thin forward to the corresponding
// named comparison member, so operator syntax and named-method calls are
// guaranteed to behave identically.
Tensor Tensor::operator<(const Tensor &other) const {
  return less_than(other);
}

Tensor Tensor::operator<=(const Tensor &other) const {
  return less_equal(other);
}

Tensor Tensor::operator==(const Tensor &other) const {
  return equal(other);
}

Tensor Tensor::operator!=(const Tensor &other) const {
  return not_equal(other);
}

Tensor Tensor::operator>(const Tensor &other) const {
  return greater_than(other);
}

Tensor Tensor::operator>=(const Tensor &other) const {
  return greater_equal(other);
}
Tensor Tensor::operator-() const { Tensor Tensor::operator-() const {
return scale(-1.0, 0.0, true); return scale(-1.0, 0.0, true);
} }
......
...@@ -4,6 +4,12 @@ ...@@ -4,6 +4,12 @@
- subtract - subtract
- multiply - multiply
- divide - divide
- less_equal
- less_than
- equal
- not_equal
- greater_equal
- greater_than
- bitwise_and - bitwise_and
- bitwise_not - bitwise_not
- bitwise_or - bitwise_or
......
...@@ -453,3 +453,93 @@ PD_BUILD_OP(custom_logical_not) ...@@ -453,3 +453,93 @@ PD_BUILD_OP(custom_logical_not)
.Inputs({"X"}) .Inputs({"X"})
.Outputs({"Out"}) .Outputs({"Out"})
.SetKernelFn(PD_KERNEL(NotForward)); .SetKernelFn(PD_KERNEL(NotForward));
// Custom operator kernel: Out = (X < Y), element-wise.
std::vector<paddle::Tensor> LessThanForward(const paddle::Tensor& x,
                                            const paddle::Tensor& y) {
  // Comparison operants are only wired up for CPU and GPU tensors.
  if (!x.is_cpu() && !x.is_gpu()) {
    PD_THROW("Not implemented.");
  }
  return {x < y};
}

PD_BUILD_OP(custom_less_than)
    .Inputs({"X", "Y"})
    .Outputs({"Out"})
    .SetKernelFn(PD_KERNEL(LessThanForward));
// Custom operator kernel: Out = (X <= Y), element-wise.
std::vector<paddle::Tensor> LessEqualForward(const paddle::Tensor& x,
                                             const paddle::Tensor& y) {
  // Comparison operants are only wired up for CPU and GPU tensors.
  if (!x.is_cpu() && !x.is_gpu()) {
    PD_THROW("Not implemented.");
  }
  return {x <= y};
}

PD_BUILD_OP(custom_less_equal)
    .Inputs({"X", "Y"})
    .Outputs({"Out"})
    .SetKernelFn(PD_KERNEL(LessEqualForward));
// Custom operator kernel: Out = (X == Y), element-wise.
std::vector<paddle::Tensor> EqualForward(const paddle::Tensor& x,
                                         const paddle::Tensor& y) {
  // Comparison operants are only wired up for CPU and GPU tensors.
  if (!x.is_cpu() && !x.is_gpu()) {
    PD_THROW("Not implemented.");
  }
  return {x == y};
}

PD_BUILD_OP(custom_equal)
    .Inputs({"X", "Y"})
    .Outputs({"Out"})
    .SetKernelFn(PD_KERNEL(EqualForward));
// Custom operator kernel: Out = (X != Y), element-wise.
std::vector<paddle::Tensor> NotEqualForward(const paddle::Tensor& x,
                                            const paddle::Tensor& y) {
  // Comparison operants are only wired up for CPU and GPU tensors.
  if (!x.is_cpu() && !x.is_gpu()) {
    PD_THROW("Not implemented.");
  }
  return {x != y};
}

PD_BUILD_OP(custom_not_equal)
    .Inputs({"X", "Y"})
    .Outputs({"Out"})
    .SetKernelFn(PD_KERNEL(NotEqualForward));
// Custom operator kernel: Out = (X > Y), element-wise.
std::vector<paddle::Tensor> GreaterThanForward(const paddle::Tensor& x,
                                               const paddle::Tensor& y) {
  // Comparison operants are only wired up for CPU and GPU tensors.
  if (!x.is_cpu() && !x.is_gpu()) {
    PD_THROW("Not implemented.");
  }
  return {x > y};
}

PD_BUILD_OP(custom_greater_than)
    .Inputs({"X", "Y"})
    .Outputs({"Out"})
    .SetKernelFn(PD_KERNEL(GreaterThanForward));
// Custom operator kernel: Out = (X >= Y), element-wise.
std::vector<paddle::Tensor> GreaterEqualForward(const paddle::Tensor& x,
                                                const paddle::Tensor& y) {
  // Comparison operants are only wired up for CPU and GPU tensors.
  if (!x.is_cpu() && !x.is_gpu()) {
    PD_THROW("Not implemented.");
  }
  return {x >= y};
}

PD_BUILD_OP(custom_greater_equal)
    .Inputs({"X", "Y"})
    .Outputs({"Out"})
    .SetKernelFn(PD_KERNEL(GreaterEqualForward));
...@@ -247,6 +247,7 @@ class TestJITLoad(unittest.TestCase): ...@@ -247,6 +247,7 @@ class TestJITLoad(unittest.TestCase):
self._test_static() self._test_static()
self._test_dynamic() self._test_dynamic()
self._test_logical_operants() self._test_logical_operants()
self._test_compare_operants()
def _test_static(self): def _test_static(self):
for device in self.devices: for device in self.devices:
...@@ -355,6 +356,38 @@ class TestJITLoad(unittest.TestCase): ...@@ -355,6 +356,38 @@ class TestJITLoad(unittest.TestCase):
pd_out = paddle.bitwise_not(x) pd_out = paddle.bitwise_not(x)
np.testing.assert_equal(out.numpy(), pd_out.numpy()) np.testing.assert_equal(out.numpy(), pd_out.numpy())
def _test_compare_operants(self):
    """Check each custom comparison op against the builtin paddle op.

    For every configured device, builds two random int32 tensors and
    asserts that each ``custom_*`` comparison operator produces the same
    result as the corresponding ``paddle.*`` comparison function.
    """
    # (custom op name, builtin reference op) pairs, checked in order.
    compare_ops = [
        ("custom_less_than", paddle.less_than),
        ("custom_less_equal", paddle.less_equal),
        ("custom_equal", paddle.equal),
        ("custom_not_equal", paddle.not_equal),
        ("custom_greater_than", paddle.greater_than),
        ("custom_greater_equal", paddle.greater_equal),
    ]
    for device in self.devices:
        paddle.set_device(device)
        # NOTE: paddle.randint returns paddle tensors, not numpy arrays
        # (the previous `np_x`/`np_y` names were misleading); values are
        # restricted to {0, 1} so every comparison outcome is exercised.
        x = paddle.to_tensor(paddle.randint(0, 2, [4, 8]), dtype="int32")
        y = paddle.to_tensor(paddle.randint(0, 2, [4, 8]), dtype="int32")
        for custom_name, paddle_op in compare_ops:
            out = getattr(self.custom_module, custom_name)(x, y)
            pd_out = paddle_op(x, y)
            np.testing.assert_equal(out.numpy(), pd_out.numpy())
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册