Unverified commit c7ba0312 authored by HongyuJia, committed by GitHub

[0D-Tensor] CINN supports softmax and flip, fix infershape (#55470)

Parent 922d2481
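For context, this commit lets CINN compile 0-D (scalar, shape []) tensor inputs for softmax and flip. Below is a minimal, illustrative sketch of the eager-mode behavior the new tests compare against; it is not part of the commit, and it assumes a Paddle build with 0-D tensor support enabled:

import paddle
import paddle.nn.functional as F

x = paddle.to_tensor(2.0)      # 0-D tensor, shape []
y = F.softmax(x)               # softmax over a single element is 1.0, still 0-D
z = paddle.flip(x, axis=[])    # flipping along no axes is a no-op, shape stays []
print(y.shape, z.shape)        # [], []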
@@ -73,8 +73,7 @@ std::shared_ptr<OpStrategy> StrategyForRelu(
 std::vector<framework::shape_t> InferShapeForRelu(
     const std::vector<framework::shape_t> &inputs_shape,
     const framework::AttrMapType &attrs) {
-  CHECK(!inputs_shape.empty())
-      << "The input's shape is empty! Please check again.";
+  CHECK(!inputs_shape.empty()) << "The inputs is empty! Please check again.";
   std::vector<framework::shape_t> res{inputs_shape[0]};
   return res;
 }
@@ -1979,8 +1978,7 @@ std::shared_ptr<OpStrategy> StrategyForSoftmax(
 std::vector<std::vector<int>> InferShapeForSoftmax(
     const std::vector<std::vector<int>> &inputs_shape,
     const framework::AttrMapType &attrs) {
-  CHECK(!inputs_shape.empty() && !inputs_shape[0].empty())
-      << "The input's shape size is 0! Please check again.";
+  CHECK(!inputs_shape.empty()) << "The inputs is empty! Please check again.";
   std::vector<std::vector<int>> res{inputs_shape[0]};
   return res;
 }
@@ -2057,8 +2055,7 @@ std::shared_ptr<OpStrategy> StrategyForDropoutInfer(
 std::vector<std::vector<int>> InferShapeForDropoutInfer(
     const std::vector<std::vector<int>> &inputs_shape,
     const framework::AttrMapType &attrs) {
-  CHECK(!inputs_shape.empty())
-      << "The input's shape size is 0! Please check again.";
+  CHECK(!inputs_shape.empty()) << "The inputs is empty! Please check again.";
   float dropout_prob = 0;
   std::string dropout_implementation = "downgrade_in_infer";
   for (auto &iter : attrs) {
...
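The infershape relaxation above is what makes 0-D inputs pass: a 0-D tensor's shape is an empty vector, so the old softmax check `CHECK(!inputs_shape[0].empty())` rejected otherwise valid 0-D inputs, while the new check only requires that at least one input is present. A small NumPy illustration of what `inputs_shape[0]` looks like for the 0-D case used in the tests below (illustrative only, not CINN code):

import numpy as np

x = np.random.randint(-10, 10, [])   # 0-D array, as in TestFlipOp.init_input below
print(x.shape)                       # () -- empty shape tuple
print(list(x.shape))                 # [] -- the empty shape vector that the old
                                     #       `!inputs_shape[0].empty()` check rejected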
@@ -575,6 +575,9 @@ create_unit_test(TestUnaryOp, "abs", paddle.abs, "builder.abs")
 create_unit_test(
     TestUnaryOp, "reciprocal", paddle.reciprocal, "builder.reciprocal"
 )
+create_unit_test(
+    TestUnaryOp, "softmax", paddle.nn.functional.softmax, "builder.softmax"
+)
 # acosh requires input value > 1.0, specific init_input instead of using create_unit_test
@@ -1118,5 +1121,43 @@ class TestMatmulOp(OpTest):
         self.check_outputs_and_grads()


+@OpTestTool.skip_if(
+    not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
+)
+class TestFlipOp(OpTest):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.init_input()
+
+    def init_input(self):
+        self.inputs = {
+            "x": np.random.randint(-10, 10, []).astype(self.dtype),
+        }
+        self.target_shape = ()
+
+    def build_paddle_program(self, target):
+        x = paddle.to_tensor(self.inputs["x"], stop_gradient=False)
+        out = paddle.flip(x, axis=[])
+
+        self.paddle_outputs = [out]
+
+    def build_cinn_program(self, target):
+        builder = NetBuilder("flip_op")
+        x = builder.create_input(
+            cinn_dtype_convert(self.dtype), self.inputs["x"].shape, "x"
+        )
+        out = builder.flip(x, [])
+
+        prog = builder.build()
+        res = self.get_cinn_output(prog, target, [x], [self.inputs["x"]], [out])
+
+        self.cinn_outputs = res
+        self.assertEqual(res[0].shape, self.target_shape)
+
+    def test_check_results(self):
+        self.check_outputs_and_grads()
+
+
 if __name__ == "__main__":
     unittest.main()