提交 6bac3e17 编写于 作者: D dongzhihong

"remove unused test net modified"

上级 df4fe671
...@@ -22,8 +22,8 @@ template <typename T> ...@@ -22,8 +22,8 @@ template <typename T>
class GaussianRandomKernel : public framework::OpKernel { class GaussianRandomKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
T mean = static_cast<T>(context.op_.GetAttr<T>("mean")); float mean = context.op_.GetAttr<float>("mean");
T std = static_cast<T>(context.op_.GetAttr<T>("std")); float std = context.op_.GetAttr<float>("std");
auto* tensor = context.Output<framework::Tensor>(0); auto* tensor = context.Output<framework::Tensor>(0);
T* data = tensor->mutable_data<T>(context.GetPlace()); T* data = tensor->mutable_data<T>(context.GetPlace());
...@@ -35,7 +35,8 @@ class GaussianRandomKernel : public framework::OpKernel { ...@@ -35,7 +35,8 @@ class GaussianRandomKernel : public framework::OpKernel {
} }
std::mt19937 g(seed); std::mt19937 g(seed);
std::normal_distribution<T> distribution(mean, std); std::normal_distribution<T> distribution(mean, std);
for (int i = 0; i < framework::product(tensor->dims()); ++i) { ssize_t size = framework::product(tensor->dims());
for (int i = 0; i < size; ++i) {
data[i] = distribution(g); data[i] = distribution(g);
} }
} }
......
...@@ -26,8 +26,8 @@ template <typename T> ...@@ -26,8 +26,8 @@ template <typename T>
class GaussianRandomKernel : public framework::OpKernel { class GaussianRandomKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
T mean = static_cast<T>(context.op_.GetAttr<T>("mean")); float mean = context.op_.GetAttr<float>("mean");
T std = static_cast<T>(context.op_.GetAttr<T>("std")); float std = context.op_.GetAttr<float>("std");
auto* tensor = context.Output<framework::Tensor>(0); auto* tensor = context.Output<framework::Tensor>(0);
T* data = tensor->mutable_data<T>(context.GetPlace()); T* data = tensor->mutable_data<T>(context.GetPlace());
...@@ -40,7 +40,6 @@ class GaussianRandomKernel : public framework::OpKernel { ...@@ -40,7 +40,6 @@ class GaussianRandomKernel : public framework::OpKernel {
&g, CURAND_RNG_PSEUDO_DEFAULT)); &g, CURAND_RNG_PSEUDO_DEFAULT));
PADDLE_ENFORCE( PADDLE_ENFORCE(
platform::dynload::curandSetPseudoRandomGeneratorSeed(g, seed)); platform::dynload::curandSetPseudoRandomGeneratorSeed(g, seed));
// auto g = const_cast<platform::GPUDeviceContext*>(ctx)->RandGenerator();
curandGenerateNormal(g, data, framework::product(tensor->dims()), mean, curandGenerateNormal(g, data, framework::product(tensor->dims()), mean,
std); std);
} }
......
...@@ -14,13 +14,15 @@ class GaussianRandomTest(unittest.TestCase): ...@@ -14,13 +14,15 @@ class GaussianRandomTest(unittest.TestCase):
def test_gaussian_random(self, place): def test_gaussian_random(self, place):
scope = core.Scope() scope = core.Scope()
scope.new_var("Out").get_tensor() scope.new_var("Out").get_tensor()
op = Operator( op = Operator(
"gaussian_random", "gaussian_random",
Out="Out", Out="Out",
dims=[1000, 784], dims=[1000, 784],
mean=.0, mean=.0,
std=1., std=1.,
seed=0) seed=10)
op.infer_shape(scope) op.infer_shape(scope)
context = core.DeviceContext.create(place) context = core.DeviceContext.create(place)
op.run(scope, context) op.run(scope, context)
......
...@@ -16,13 +16,13 @@ class TestNet(unittest.TestCase): ...@@ -16,13 +16,13 @@ class TestNet(unittest.TestCase):
net.complete_add_op(True) net.complete_add_op(True)
expected = ''' expected = '''
Op(plain_net), inputs:(@EMPTY@, X, Y, w), outputs:(@TEMP@fc@0, Out, fc.out). Op(plain_net), inputs:(@EMPTY@, X, Y, w), outputs:(@TEMP@fc@0, Out, fc.out).
Op(add_two), inputs:(X, Y), outputs:(Out). Op(add_two), inputs:(X, Y), outputs:(Out).
Op(plain_net), inputs:(@EMPTY@, X, w), outputs:(@TEMP@fc@0, fc.out). Op(plain_net), inputs:(@EMPTY@, X, w), outputs:(@TEMP@fc@0, fc.out).
Op(fc), inputs:(X, w, @EMPTY@), outputs:(fc.out, @TEMP@fc@0). Op(fc), inputs:(X, w, @EMPTY@), outputs:(fc.out, @TEMP@fc@0).
Op(mul), inputs:(X, w), outputs:(@TEMP@fc@0). Op(mul), inputs:(X, w), outputs:(@TEMP@fc@0).
Op(sigmoid), inputs:(@TEMP@fc@0), outputs:(fc.out). Op(sigmoid), inputs:(@TEMP@fc@0), outputs:(fc.out).
''' '''
self.assertEqual(expected, "\n" + str(net)) self.assertEqual(expected, "\n" + str(net))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册