提交 1918ad87 编写于 作者: Q qijun

fix gpu build error

上级 36e8e725
if(WITH_GPU)
nv_library(math_function SRCS math_function.cc math_function.cu DEPS cblas device_context eigen3)
nv_library(math_function SRCS math_function.cc math_function.cu DEPS cblas device_context)
else()
cc_library(math_function SRCS math_function.cc DEPS cblas device_context eigen3)
cc_library(math_function SRCS math_function.cc DEPS cblas device_context)
endif()
nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
......@@ -109,15 +109,6 @@ void matmul<platform::CPUPlace, double>(const framework::Tensor& matrix_a,
matrix_b.data<double>(), beta, matrix_out->data<double>(), context);
}
template <>
void Set<platform::CPUPlace, float>(const int n, const float alpha,
float* output,
platform::DeviceContext* context) {
auto* cpu_context = reinterpret_cast<platform::CPUDeviceContext*>(context);
framework::EigenVector<float>::Type out(output, n);
out.device(*(cpu_context->eigen_device())) = out.constant(float(alpha));
}
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -12,10 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/operators/math/math_function.h"
namespace paddle {
......@@ -126,15 +122,6 @@ void matmul<platform::GPUPlace, double>(const framework::Tensor& matrix_a,
matrix_b.data<double>(), beta, matrix_out->data<double>(), context);
}
template <>
void Set<platform::GPUPlace, float>(const int n, const float alpha,
float* output,
platform::DeviceContext* context) {
auto* cuda_context = reinterpret_cast<platform::CUDADeviceContext*>(context);
framework::EigenVector<float>::Type out(output, n);
out.device(*(cuda_context->eigen_device())) = out.constant(float(alpha));
}
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -52,7 +52,6 @@ int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda,
#include <cmath>
#include "paddle/framework/eigen.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/enforce.h"
......@@ -78,10 +77,6 @@ void matmul(const framework::Tensor& matrix_a, bool trans_a,
framework::Tensor* matrix_out, T beta,
platform::DeviceContext* context);
template <typename Place, typename T>
void Set(const int n, const T alpha, T* output,
platform::DeviceContext* context);
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -43,8 +43,6 @@ TEST(Device, CUDADeviceContext) {
ASSERT_NE(nullptr, cudnn_handle);
cublasHandle_t cublas_handle = device_context->cublas_handle();
ASSERT_NE(nullptr, cublas_handle);
curandGenerator_t curand_handle = device_context->curand_generator();
ASSERT_NE(nullptr, curand_handle);
ASSERT_NE(nullptr, device_context->stream());
delete device_context;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册