From aadb098138efafc60eaa4b902db04f78db1e62b4 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Tue, 7 Nov 2017 15:13:36 -0800 Subject: [PATCH] Add `op::math::set_constant` without template --- paddle/operators/math/math_function.cc | 48 +++++++++++++++++++++ paddle/operators/math/math_function.cu | 24 +++++++++++ paddle/operators/math/math_function.h | 7 +++ paddle/operators/math/math_function_test.cc | 12 ++++++ 4 files changed, 91 insertions(+) diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc index 2a9c09a0f1..175df2030d 100644 --- a/paddle/operators/math/math_function.cc +++ b/paddle/operators/math/math_function.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/operators/math/math_function.h" +#include "paddle/framework/data_type.h" namespace paddle { namespace operators { @@ -233,6 +234,53 @@ void gemv(const platform::DeviceContext& context, template struct SetConstant; +struct TensorSetConstant { + TensorSetConstant(framework::Tensor* tensor, float value) + : tensor_(tensor), value_(value) {} + template + void operator()() const { + auto cpu = platform::CPUPlace(); + auto* begin = tensor_->mutable_data(cpu); + std::fill(begin, begin + tensor_->numel(), static_cast(value_)); + } + framework::Tensor* tensor_; + float value_; +}; + +template <> +void set_constant_with_place( + const platform::DeviceContext& context, framework::Tensor* tensor, + float value) { + framework::VisitDataType(framework::ToDataType(tensor->type()), + TensorSetConstant(tensor, value)); +} + +struct TensorSetConstantWithPlace : public boost::static_visitor { + TensorSetConstantWithPlace(const platform::DeviceContext& context, + framework::Tensor* tensor, float value) + : context_(context), tensor_(tensor), value_(value) {} + + template + void operator()(Place place) const { + set_constant_with_place(context_, tensor_, value_); + } + + const platform::DeviceContext& context_; + framework::Tensor* tensor_; + float value_; +}; + +void set_constant(const platform::DeviceContext& context, + framework::Tensor* tensor, float value) { +#ifdef PADDLE_WITH_CUDA + boost::apply_visitor(TensorSetConstantWithPlace(context, tensor, value), + tensor->place()); +#else + TensorSetConstantWithPlace func(context, tensor, value); + func(platform::CPUPlace()); +#endif +} + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index e6fd8bf235..3a216993ac 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -232,6 +232,30 @@ void gemv(const platform::DeviceContext& context, template struct SetConstant; +struct TensorSetConstant { + TensorSetConstant(const platform::DeviceContext& context, + framework::Tensor* tensor, float value) + : context_(context), tensor_(tensor), value_(value) {} + + template + void operator()() const { + SetConstant functor; + functor(context_, tensor_, static_cast(value_)); + } + + const platform::DeviceContext& context_; + framework::Tensor* tensor_; + float value_; +}; + +template <> +void set_constant_with_place( + const platform::DeviceContext& context, framework::Tensor* tensor, + float value) { + framework::VisitDataType(framework::ToDataType(tensor->type()), + TensorSetConstant(context, tensor, value)); +} + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h index 3bb5aa0332..1c9eabb2b7 100644 --- a/paddle/operators/math/math_function.h +++ b/paddle/operators/math/math_function.h @@ -108,6 +108,13 @@ struct SetConstant { } }; +template +void set_constant_with_place(const platform::DeviceContext& context, + framework::Tensor* tensor, float value); + +void set_constant(const platform::DeviceContext& context, + framework::Tensor* tensor, float value); + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/math_function_test.cc b/paddle/operators/math/math_function_test.cc index 7d84ad9aad..983c9fdcff 100644 --- a/paddle/operators/math/math_function_test.cc +++ b/paddle/operators/math/math_function_test.cc @@ -139,3 +139,15 @@ TEST(math_function, gemv) { GemvTest(12, 7, true); GemvTest(7, 9, true); } + +TEST(math_funciton, set_constant) { + paddle::framework::Tensor t; + t.Resize({10, 10}); + t.mutable_data(paddle::platform::CPUPlace()); + auto* ctx = new paddle::platform::CPUDeviceContext(); + paddle::operators::math::set_constant(*ctx, &t, 10); + for (int64_t i = 0; i < t.numel(); ++i) { + PADDLE_ENFORCE_EQ(10, t.data()[i]); + } + delete ctx; +} -- GitLab