提交 1918ad87 编写于 作者: Q qijun

fix gpu build error

上级 36e8e725
if(WITH_GPU) if(WITH_GPU)
nv_library(math_function SRCS math_function.cc math_function.cu DEPS cblas device_context eigen3) nv_library(math_function SRCS math_function.cc math_function.cu DEPS cblas device_context)
else() else()
cc_library(math_function SRCS math_function.cc DEPS cblas device_context eigen3) cc_library(math_function SRCS math_function.cc DEPS cblas device_context)
endif() endif()
nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
...@@ -109,15 +109,6 @@ void matmul<platform::CPUPlace, double>(const framework::Tensor& matrix_a, ...@@ -109,15 +109,6 @@ void matmul<platform::CPUPlace, double>(const framework::Tensor& matrix_a,
matrix_b.data<double>(), beta, matrix_out->data<double>(), context); matrix_b.data<double>(), beta, matrix_out->data<double>(), context);
} }
template <>
void Set<platform::CPUPlace, float>(const int n, const float alpha,
float* output,
platform::DeviceContext* context) {
auto* cpu_context = reinterpret_cast<platform::CPUDeviceContext*>(context);
framework::EigenVector<float>::Type out(output, n);
out.device(*(cpu_context->eigen_device())) = out.constant(float(alpha));
}
} // namespace math } // namespace math
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -12,10 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,10 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/operators/math/math_function.h" #include "paddle/operators/math/math_function.h"
namespace paddle { namespace paddle {
...@@ -126,15 +122,6 @@ void matmul<platform::GPUPlace, double>(const framework::Tensor& matrix_a, ...@@ -126,15 +122,6 @@ void matmul<platform::GPUPlace, double>(const framework::Tensor& matrix_a,
matrix_b.data<double>(), beta, matrix_out->data<double>(), context); matrix_b.data<double>(), beta, matrix_out->data<double>(), context);
} }
template <>
void Set<platform::GPUPlace, float>(const int n, const float alpha,
float* output,
platform::DeviceContext* context) {
auto* cuda_context = reinterpret_cast<platform::CUDADeviceContext*>(context);
framework::EigenVector<float>::Type out(output, n);
out.device(*(cuda_context->eigen_device())) = out.constant(float(alpha));
}
} // namespace math } // namespace math
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -52,7 +52,6 @@ int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda, ...@@ -52,7 +52,6 @@ int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda,
#include <cmath> #include <cmath>
#include "paddle/framework/eigen.h"
#include "paddle/framework/tensor.h" #include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h" #include "paddle/platform/device_context.h"
#include "paddle/platform/enforce.h" #include "paddle/platform/enforce.h"
...@@ -78,10 +77,6 @@ void matmul(const framework::Tensor& matrix_a, bool trans_a, ...@@ -78,10 +77,6 @@ void matmul(const framework::Tensor& matrix_a, bool trans_a,
framework::Tensor* matrix_out, T beta, framework::Tensor* matrix_out, T beta,
platform::DeviceContext* context); platform::DeviceContext* context);
template <typename Place, typename T>
void Set(const int n, const T alpha, T* output,
platform::DeviceContext* context);
} // namespace math } // namespace math
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -43,8 +43,6 @@ TEST(Device, CUDADeviceContext) { ...@@ -43,8 +43,6 @@ TEST(Device, CUDADeviceContext) {
ASSERT_NE(nullptr, cudnn_handle); ASSERT_NE(nullptr, cudnn_handle);
cublasHandle_t cublas_handle = device_context->cublas_handle(); cublasHandle_t cublas_handle = device_context->cublas_handle();
ASSERT_NE(nullptr, cublas_handle); ASSERT_NE(nullptr, cublas_handle);
curandGenerator_t curand_handle = device_context->curand_generator();
ASSERT_NE(nullptr, curand_handle);
ASSERT_NE(nullptr, device_context->stream()); ASSERT_NE(nullptr, device_context->stream());
delete device_context; delete device_context;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册