提交 14cfb8c2 编写于 作者: Q qijun

fix gpu build error

上级 b6c07552
...@@ -124,6 +124,7 @@ set(GPU_COMMON_FLAGS ...@@ -124,6 +124,7 @@ set(GPU_COMMON_FLAGS
-Wno-error=literal-suffix -Wno-error=literal-suffix
-Wno-error=unused-local-typedefs -Wno-error=unused-local-typedefs
-Wno-error=unused-function # Warnings in Numpy Header. -Wno-error=unused-function # Warnings in Numpy Header.
-Wno-error=array-bounds # Warnings in Eigen::array
) )
if (APPLE) if (APPLE)
......
...@@ -24,9 +24,10 @@ template <typename Place, typename T> ...@@ -24,9 +24,10 @@ template <typename Place, typename T>
class MulKernel : public framework::OpKernel { class MulKernel : public framework::OpKernel {
public: public:
void Compute(const framework::KernelContext& context) const override { void Compute(const framework::KernelContext& context) const override {
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair; Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair = {
dim_pair[0].first = 1; Eigen::IndexPair<Eigen::DenseIndex>(1, 0)};
dim_pair[0].second = 0; // dim_pair[0].first = 1;
// dim_pair[0].second = 0;
auto input0 = context.Input(0)->Get<framework::Tensor>(); auto input0 = context.Input(0)->Get<framework::Tensor>();
auto input1 = context.Input(1)->Get<framework::Tensor>(); auto input1 = context.Input(1)->Get<framework::Tensor>();
......
...@@ -26,6 +26,7 @@ public: ...@@ -26,6 +26,7 @@ public:
auto in0 = context.Input(0)->Get<framework::Tensor>(); auto in0 = context.Input(0)->Get<framework::Tensor>();
auto in1 = context.Input(1)->Get<framework::Tensor>(); auto in1 = context.Input(1)->Get<framework::Tensor>();
auto* out = context.Output(0)->GetMutable<framework::Tensor>(); auto* out = context.Output(0)->GetMutable<framework::Tensor>();
out->mutable_data<T>(context.GetPlace());
auto input = in0.matrix<T>(); auto input = in0.matrix<T>();
auto bias = in1.vec<T>(); auto bias = in1.vec<T>();
......
...@@ -26,6 +26,7 @@ public: ...@@ -26,6 +26,7 @@ public:
void Compute(const framework::KernelContext& context) const override { void Compute(const framework::KernelContext& context) const override {
auto input = context.Input(0)->Get<framework::Tensor>(); auto input = context.Input(0)->Get<framework::Tensor>();
auto* output = context.Output(0)->GetMutable<framework::Tensor>(); auto* output = context.Output(0)->GetMutable<framework::Tensor>();
output->mutable_data<T>(context.GetPlace());
auto logits = input.matrix<T>(); auto logits = input.matrix<T>();
auto softmax = output->matrix<T>(); auto softmax = output->matrix<T>();
...@@ -40,7 +41,8 @@ public: ...@@ -40,7 +41,8 @@ public:
Eigen::DSizes<int, 2> batch_by_one(batch_size, 1); Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
Eigen::DSizes<int, 2> one_by_class(1, num_classes); Eigen::DSizes<int, 2> one_by_class(1, num_classes);
auto shifted_logits = (logits - logits.maximum(along_class) auto shifted_logits = (logits -
logits.maximum(along_class)
.eval() .eval()
.reshape(batch_by_one) .reshape(batch_by_one)
.broadcast(one_by_class)); .broadcast(one_by_class));
...@@ -48,7 +50,8 @@ public: ...@@ -48,7 +50,8 @@ public:
softmax.device(*(context.GetEigenDevice<Place>())) = shifted_logits.exp(); softmax.device(*(context.GetEigenDevice<Place>())) = shifted_logits.exp();
softmax.device(*(context.GetEigenDevice<Place>())) = softmax.device(*(context.GetEigenDevice<Place>())) =
(softmax * softmax.sum(along_class) (softmax *
softmax.sum(along_class)
.inverse() .inverse()
.eval() .eval()
.reshape(batch_by_one) .reshape(batch_by_one)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册