提交 3014662d 编写于 作者: Q qijun

Merge branch 'implement_basic_OpKernel' of https://github.com/QiJune/Paddle...

Merge branch 'implement_basic_OpKernel' of https://github.com/QiJune/Paddle into implement_basic_OpKernel
......@@ -124,6 +124,7 @@ set(GPU_COMMON_FLAGS
-Wno-error=literal-suffix
-Wno-error=unused-local-typedefs
-Wno-error=unused-function # Warnings in Numpy Header.
-Wno-error=array-bounds # Warnings in Eigen::array
)
if (APPLE)
......
......@@ -25,9 +25,10 @@ template <typename Place, typename T>
class MulKernel : public framework::OpKernel {
public:
void Compute(const framework::KernelContext& context) const override {
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
dim_pair[0].first = 1;
dim_pair[0].second = 0;
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair = {
Eigen::IndexPair<Eigen::DenseIndex>(1, 0)};
// dim_pair[0].first = 1;
// dim_pair[0].second = 0;
auto input0 = context.Input(0)->Get<framework::Tensor>();
auto input1 = context.Input(1)->Get<framework::Tensor>();
......
......@@ -27,6 +27,7 @@ public:
auto in0 = context.Input(0)->Get<framework::Tensor>();
auto in1 = context.Input(1)->Get<framework::Tensor>();
auto* out = context.Output(0)->GetMutable<framework::Tensor>();
out->mutable_data<T>(context.GetPlace());
auto input = framework::EigenMatrix<T>::From(in0);
auto bias = framework::EigenVector<T>::From(in1);
......
......@@ -27,6 +27,7 @@ public:
void Compute(const framework::KernelContext& context) const override {
auto input = context.Input(0)->Get<framework::Tensor>();
auto* output = context.Output(0)->GetMutable<framework::Tensor>();
output->mutable_data<T>(context.GetPlace());
auto logits = framework::EigenMatrix<T>::From(input);
auto softmax = framework::EigenMatrix<T>::From(*output);
......@@ -41,19 +42,21 @@ public:
Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
Eigen::DSizes<int, 2> one_by_class(1, num_classes);
auto shifted_logits = (logits - logits.maximum(along_class)
.eval()
.reshape(batch_by_one)
.broadcast(one_by_class));
auto shifted_logits = (logits -
logits.maximum(along_class)
.eval()
.reshape(batch_by_one)
.broadcast(one_by_class));
softmax.device(*(context.GetEigenDevice<Place>())) = shifted_logits.exp();
softmax.device(*(context.GetEigenDevice<Place>())) =
(softmax * softmax.sum(along_class)
.inverse()
.eval()
.reshape(batch_by_one)
.broadcast(one_by_class));
(softmax *
softmax.sum(along_class)
.inverse()
.eval()
.reshape(batch_by_one)
.broadcast(one_by_class));
}
};
} // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册