From 984225ecf198525a134acbda0fb6cab177a59ebd Mon Sep 17 00:00:00 2001 From: dongzhihong Date: Tue, 25 Jul 2017 16:07:08 +0800 Subject: [PATCH] "fix operator" --- paddle/framework/operator.cc | 14 ++++- paddle/operators/random_op.cc | 23 ++------ paddle/operators/random_op.cu | 13 ----- paddle/operators/random_op.h | 54 +++++++------------ .../paddle/v2/framework/tests/CMakeLists.txt | 3 +- .../v2/framework/tests/test_random_op.py | 7 +-- 6 files changed, 39 insertions(+), 75 deletions(-) diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 1e57e9a20f..18e327089f 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include - #include "paddle/framework/operator.h" +#include +#include namespace paddle { namespace framework { @@ -95,6 +95,16 @@ std::string OperatorBase::DebugString() const { ss << ", "; } } + ss << "), "; + ss << "Attrs:("; + size_t i = 0; + for (auto& attr : attrs_) { + ss << attr.first; + if (i != attrs_.size() - 1) { + ss << ", "; + } + i++; + } ss << ")."; return ss.str(); } diff --git a/paddle/operators/random_op.cc b/paddle/operators/random_op.cc index 05a3dbd9f4..726f6504e7 100644 --- a/paddle/operators/random_op.cc +++ b/paddle/operators/random_op.cc @@ -13,28 +13,12 @@ limitations under the License. */ #include "paddle/operators/random_op.h" +#include "glog/logging.h" #include "paddle/framework/op_registry.h" namespace paddle { namespace operators { -// using paddle::platform::CPUPlace; -// template -template -bool Gaussian(platform::CPUDeviceContext& ctx, - framework::Tensor* output, - const int size, - const T& mean, - const T& std, - const T& seed) { - auto g = ctx.RandGenerator(seed); - std::normal_distribution distribution(mean, std); - for (int i = 0; i < size; ++i) { - output[i] = distribution(g()); - } - return true; -} - class RandomOp : public framework::OperatorWithKernel { protected: void InferShape( @@ -42,11 +26,10 @@ protected: const std::vector& outputs) const override { PADDLE_ENFORCE(inputs.size() == 0, "Input size of RandomOp must be zero."); PADDLE_ENFORCE(outputs.size() == 1, "Output size of RandomOp must be one."); - PADDLE_ENFORCE(inputs[0] != nullptr && outputs[0] != nullptr, - "Inputs/Outputs of RandomOp must all be set."); + PADDLE_ENFORCE(outputs[0] != nullptr, + "Outputs of RandomOp must all be set."); outputs[0]->Resize( framework::make_ddim(this->GetAttr>("shape"))); - // outputs[0]->set_dims(context.op_.attrs_.at("shape")); } }; diff --git a/paddle/operators/random_op.cu b/paddle/operators/random_op.cu index 85054974ac..b417666c98 100644 --- a/paddle/operators/random_op.cu +++ b/paddle/operators/random_op.cu @@ -1,19 +1,6 @@ #include "paddle/operators/random_op.h" #include "paddle/framework/op_registry.h" -namespace paddle { -namespace operators { - -template -bool Gaussian(platform::CUDADeviceContext &ctx, framework::Tensor* output, - const int size, const T& mean, const T& std, const T& seed) { - auto g = RandGenerator(seed); - return curandGenerateNormal(g, output, size, mean, std); -} - -} // operators -} // paddle - typedef paddle::operators::RandomOpKernel RandomOpKernel_GPU_float; diff --git a/paddle/operators/random_op.h b/paddle/operators/random_op.h index 3eeb1f87c8..f8e1a90a1d 100644 --- a/paddle/operators/random_op.h +++ b/paddle/operators/random_op.h @@ -13,7 +13,9 @@ bool Gaussian(DeviceContext& ctx, const int size, const T& mean, const T& std, - const T& seed); + const T& seed) { + return false; +} template bool Gaussian(platform::CPUDeviceContext& ctx, @@ -21,14 +23,27 @@ bool Gaussian(platform::CPUDeviceContext& ctx, const int size, const T& mean, const T& std, - const T& seed); + const T& seed) { + auto g = ctx.RandGenerator(seed); + std::normal_distribution distribution(mean, std); + for (int i = 0; i < size; ++i) { + output[i] = distribution(g); + } + return true; +} + +#ifndef PADDLE_ONLY_CPU template bool Gaussian(platform::CUDADeviceContext& ctx, framework::Tensor* output, const int size, const T& mean, const T& std, - const T& seed); + const T& seed) { + auto g = RandGenerator(seed); + return curandGenerateNormal(g, output, size, mean, std); +} +#endif template class RandomOpKernel : public framework::OpKernel { @@ -45,41 +60,8 @@ public: mean, std, seed); - // Gaussian(context.device_context_, - // output, - // framework::product(output->dims()), - // mean, std, seed); - // std::default_random_engine generator(seed); - // std::normal_distribution distribution(mean, std); - - // framework::EigenMatrix::From(*output).device(*( - // context.GetEigenDevice())) = - // framework::EigenMatrix::Random(); } }; -// using paddle::platform::CPUPlace; -// template -// class RandomOpKernel : public framework::OpKernel { -// public: -// void Compute(const framework::KernelContext& context) const override { - -// std::unique_ptr generator(seed); -// for(size_t i=0; i < output->size(); ++i) { -// output[i] = distribution(generator()); -// } -// } - -// }; - -// using paddle::platform::GPUPlace; -// template -// class RandomOpKernel : public framework::OpKernel { -// public: -// void Compute(const framework::KernelContext& context) const override { - -// } -// } - } // namespace operators } // namespace paddle diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt index b3eb2ef8a8..254e8d37d1 100644 --- a/python/paddle/v2/framework/tests/CMakeLists.txt +++ b/python/paddle/v2/framework/tests/CMakeLists.txt @@ -12,4 +12,5 @@ add_python_test(test_framework test_mul_op.py test_sigmoid_op.py test_softmax_op.py - test_rowwise_add_op.py) + test_rowwise_add_op.py + test_random_op.py) diff --git a/python/paddle/v2/framework/tests/test_random_op.py b/python/paddle/v2/framework/tests/test_random_op.py index eb69f35edf..e2aa9bdfc2 100644 --- a/python/paddle/v2/framework/tests/test_random_op.py +++ b/python/paddle/v2/framework/tests/test_random_op.py @@ -15,13 +15,14 @@ class TestRandomOp(unittest.TestCase): if scope.get_var(out) is None: scope.create_var(out).get_tensor() - tensor = scope.get_var("Y").get_tensor() + tensor = scope.get_var("Out").get_tensor() op.infer_shape(scope) self.assertEqual([1000, 1000], tensor.shape()) ctx = core.DeviceContext.cpu_context() op.run(scope, ctx) - self.assertAlmostEqual(numpy.std(tensor), 1.0) - self.assertAlmostEqual(numpy.mean(tensor), 5.0) + tensor_array = numpy.array(tensor) + self.assertAlmostEqual(numpy.std(tensor_array), 1.0) + self.assertAlmostEqual(numpy.mean(tensor_array), 5.0) if __name__ == '__main__': -- GitLab