提交 6bac3e17 编写于 作者: D dongzhihong

"remove unused test net modified"

上级 df4fe671
......@@ -22,8 +22,8 @@ template <typename T>
class GaussianRandomKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
T mean = static_cast<T>(context.op_.GetAttr<T>("mean"));
T std = static_cast<T>(context.op_.GetAttr<T>("std"));
float mean = context.op_.GetAttr<float>("mean");
float std = context.op_.GetAttr<float>("std");
auto* tensor = context.Output<framework::Tensor>(0);
T* data = tensor->mutable_data<T>(context.GetPlace());
......@@ -35,7 +35,8 @@ class GaussianRandomKernel : public framework::OpKernel {
}
std::mt19937 g(seed);
std::normal_distribution<T> distribution(mean, std);
for (int i = 0; i < framework::product(tensor->dims()); ++i) {
ssize_t size = framework::product(tensor->dims());
for (int i = 0; i < size; ++i) {
data[i] = distribution(g);
}
}
......
......@@ -26,8 +26,8 @@ template <typename T>
class GaussianRandomKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
T mean = static_cast<T>(context.op_.GetAttr<T>("mean"));
T std = static_cast<T>(context.op_.GetAttr<T>("std"));
float mean = context.op_.GetAttr<float>("mean");
float std = context.op_.GetAttr<float>("std");
auto* tensor = context.Output<framework::Tensor>(0);
T* data = tensor->mutable_data<T>(context.GetPlace());
......@@ -40,7 +40,6 @@ class GaussianRandomKernel : public framework::OpKernel {
&g, CURAND_RNG_PSEUDO_DEFAULT));
PADDLE_ENFORCE(
platform::dynload::curandSetPseudoRandomGeneratorSeed(g, seed));
// auto g = const_cast<platform::GPUDeviceContext*>(ctx)->RandGenerator();
curandGenerateNormal(g, data, framework::product(tensor->dims()), mean,
std);
}
......
......@@ -14,13 +14,15 @@ class GaussianRandomTest(unittest.TestCase):
def test_gaussian_random(self, place):
scope = core.Scope()
scope.new_var("Out").get_tensor()
op = Operator(
"gaussian_random",
Out="Out",
dims=[1000, 784],
mean=.0,
std=1.,
seed=0)
seed=10)
op.infer_shape(scope)
context = core.DeviceContext.create(place)
op.run(scope, context)
......
......@@ -16,13 +16,13 @@ class TestNet(unittest.TestCase):
net.complete_add_op(True)
expected = '''
Op(plain_net), inputs:(@EMPTY@, X, Y, w), outputs:(@TEMP@fc@0, Out, fc.out).
Op(add_two), inputs:(X, Y), outputs:(Out).
Op(plain_net), inputs:(@EMPTY@, X, w), outputs:(@TEMP@fc@0, fc.out).
Op(plain_net), inputs:(@EMPTY@, X, Y, w), outputs:(@TEMP@fc@0, Out, fc.out).
Op(add_two), inputs:(X, Y), outputs:(Out).
Op(plain_net), inputs:(@EMPTY@, X, w), outputs:(@TEMP@fc@0, fc.out).
Op(fc), inputs:(X, w, @EMPTY@), outputs:(fc.out, @TEMP@fc@0).
Op(mul), inputs:(X, w), outputs:(@TEMP@fc@0).
Op(sigmoid), inputs:(@TEMP@fc@0), outputs:(fc.out).
'''
Op(mul), inputs:(X, w), outputs:(@TEMP@fc@0).
Op(sigmoid), inputs:(@TEMP@fc@0), outputs:(fc.out).
'''
self.assertEqual(expected, "\n" + str(net))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册