未验证 提交 30f059d6 编写于 作者: H HongyuJia 提交者: GitHub

[0D-Tensor] CINN supports gaussian_random (#55547)

上级 5bfbaa8b
......@@ -81,7 +81,6 @@ std::vector<framework::shape_t> InferShapeForGaussianRandom(
const framework::AttrMapType &attrs) {
CHECK(attrs.count("shape"));
auto shape = absl::get<std::vector<int>>(attrs.at("shape"));
CHECK(!shape.empty()) << "shape attr is empty!";
return {shape};
}
......
......@@ -1299,6 +1299,45 @@ class TestSqueezeOp2D(TestSqueezeOp):
self.target_shape = ()
@OpTestTool.skip_if(
not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
)
class TestGaussianRandomOp(OpTest):
def setUp(self):
np.random.seed(2023)
self.dtype = "float32"
self.target_shape = ()
def build_paddle_program(self, target):
out = paddle.tensor.random.gaussian(
shape=[],
mean=0.0,
std=0.0,
dtype=self.dtype,
)
self.paddle_outputs = [out]
def build_cinn_program(self, target):
builder = NetBuilder("gaussian_random_op")
out = builder.gaussian_random(
[],
0.0,
0.0,
1234,
self.dtype,
)
prog = builder.build()
res = self.get_cinn_output(prog, target, [], [], [out])
self.cinn_outputs = res
self.assertEqual(res[0].shape, self.target_shape)
def test_check_results(self):
self.check_outputs_and_grads()
@OpTestTool.skip_if(
not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册