diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc index 09c3f0b1e6f787547b9253d3aeadf70674708ba0..1b0d4c8bdc683b5203a4bc4b3838560cffe00bc8 100644 --- a/paddle/operators/math/math_function.cc +++ b/paddle/operators/math/math_function.cc @@ -234,8 +234,8 @@ void gemv(const platform::DeviceContext& context, template struct SetConstant; -struct TensorSetConstant { - TensorSetConstant(framework::Tensor* tensor, float value) +struct TensorSetConstantCPU { + TensorSetConstantCPU(framework::Tensor* tensor, float value) : tensor_(tensor), value_(value) {} template void operator()() const { @@ -252,7 +252,7 @@ void set_constant_with_place( const platform::DeviceContext& context, framework::Tensor* tensor, float value) { framework::VisitDataType(framework::ToDataType(tensor->type()), - TensorSetConstant(tensor, value)); + TensorSetConstantCPU(tensor, value)); } struct TensorSetConstantWithPlace : public boost::static_visitor { diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index 255e480680499877ff599b96b8336a968cccbb34..817deec94314bdfd2ed7e4b0ba5212c72b813455 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -233,8 +233,8 @@ void gemv(const platform::DeviceContext& context, template struct SetConstant; -struct TensorSetConstant { - TensorSetConstant(const platform::DeviceContext& context, +struct TensorSetConstantGPU { + TensorSetConstantGPU(const platform::DeviceContext& context, framework::Tensor* tensor, float value) : context_(context), tensor_(tensor), value_(value) {} @@ -254,7 +254,7 @@ void set_constant_with_place( const platform::DeviceContext& context, framework::Tensor* tensor, float value) { framework::VisitDataType(framework::ToDataType(tensor->type()), - TensorSetConstant(context, tensor, value)); + TensorSetConstantGPU(context, tensor, value)); } } // namespace math