提交 14cfb8c2 编写于 作者: Q qijun

fix gpu build error

上级 b6c07552
......@@ -124,6 +124,7 @@ set(GPU_COMMON_FLAGS
-Wno-error=literal-suffix
-Wno-error=unused-local-typedefs
-Wno-error=unused-function # Warnings in Numpy Header.
-Wno-error=array-bounds # Warnings in Eigen::array
)
if (APPLE)
......
......@@ -24,9 +24,10 @@ template <typename Place, typename T>
class MulKernel : public framework::OpKernel {
public:
void Compute(const framework::KernelContext& context) const override {
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
dim_pair[0].first = 1;
dim_pair[0].second = 0;
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair = {
Eigen::IndexPair<Eigen::DenseIndex>(1, 0)};
// dim_pair[0].first = 1;
// dim_pair[0].second = 0;
auto input0 = context.Input(0)->Get<framework::Tensor>();
auto input1 = context.Input(1)->Get<framework::Tensor>();
......
......@@ -26,6 +26,7 @@ public:
auto in0 = context.Input(0)->Get<framework::Tensor>();
auto in1 = context.Input(1)->Get<framework::Tensor>();
auto* out = context.Output(0)->GetMutable<framework::Tensor>();
out->mutable_data<T>(context.GetPlace());
auto input = in0.matrix<T>();
auto bias = in1.vec<T>();
......
......@@ -26,6 +26,7 @@ public:
void Compute(const framework::KernelContext& context) const override {
auto input = context.Input(0)->Get<framework::Tensor>();
auto* output = context.Output(0)->GetMutable<framework::Tensor>();
output->mutable_data<T>(context.GetPlace());
auto logits = input.matrix<T>();
auto softmax = output->matrix<T>();
......@@ -40,19 +41,21 @@ public:
Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
Eigen::DSizes<int, 2> one_by_class(1, num_classes);
auto shifted_logits = (logits - logits.maximum(along_class)
.eval()
.reshape(batch_by_one)
.broadcast(one_by_class));
auto shifted_logits = (logits -
logits.maximum(along_class)
.eval()
.reshape(batch_by_one)
.broadcast(one_by_class));
softmax.device(*(context.GetEigenDevice<Place>())) = shifted_logits.exp();
softmax.device(*(context.GetEigenDevice<Place>())) =
(softmax * softmax.sum(along_class)
.inverse()
.eval()
.reshape(batch_by_one)
.broadcast(one_by_class));
(softmax *
softmax.sum(along_class)
.inverse()
.eval()
.reshape(batch_by_one)
.broadcast(one_by_class));
}
};
} // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册