提交 3014662d，作者：qijun

Merge branch 'implement_basic_OpKernel' of https://github.com/QiJune/Paddle into implement_basic_OpKernel

Merge branch 'implement_basic_OpKernel' of https://github.com/QiJune/Paddle into implement_basic_OpKernel
...@@ -124,6 +124,7 @@ set(GPU_COMMON_FLAGS ...@@ -124,6 +124,7 @@ set(GPU_COMMON_FLAGS
-Wno-error=literal-suffix -Wno-error=literal-suffix
-Wno-error=unused-local-typedefs -Wno-error=unused-local-typedefs
-Wno-error=unused-function # Warnings in Numpy Header. -Wno-error=unused-function # Warnings in Numpy Header.
-Wno-error=array-bounds # Warnings in Eigen::array
) )
if (APPLE) if (APPLE)
......
...@@ -25,9 +25,10 @@ template <typename Place, typename T> ...@@ -25,9 +25,10 @@ template <typename Place, typename T>
class MulKernel : public framework::OpKernel { class MulKernel : public framework::OpKernel {
public: public:
void Compute(const framework::KernelContext& context) const override { void Compute(const framework::KernelContext& context) const override {
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair; Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair = {
dim_pair[0].first = 1; Eigen::IndexPair<Eigen::DenseIndex>(1, 0)};
dim_pair[0].second = 0; // dim_pair[0].first = 1;
// dim_pair[0].second = 0;
auto input0 = context.Input(0)->Get<framework::Tensor>(); auto input0 = context.Input(0)->Get<framework::Tensor>();
auto input1 = context.Input(1)->Get<framework::Tensor>(); auto input1 = context.Input(1)->Get<framework::Tensor>();
......
...@@ -27,6 +27,7 @@ public: ...@@ -27,6 +27,7 @@ public:
auto in0 = context.Input(0)->Get<framework::Tensor>(); auto in0 = context.Input(0)->Get<framework::Tensor>();
auto in1 = context.Input(1)->Get<framework::Tensor>(); auto in1 = context.Input(1)->Get<framework::Tensor>();
auto* out = context.Output(0)->GetMutable<framework::Tensor>(); auto* out = context.Output(0)->GetMutable<framework::Tensor>();
out->mutable_data<T>(context.GetPlace());
auto input = framework::EigenMatrix<T>::From(in0); auto input = framework::EigenMatrix<T>::From(in0);
auto bias = framework::EigenVector<T>::From(in1); auto bias = framework::EigenVector<T>::From(in1);
......
...@@ -27,6 +27,7 @@ public: ...@@ -27,6 +27,7 @@ public:
void Compute(const framework::KernelContext& context) const override { void Compute(const framework::KernelContext& context) const override {
auto input = context.Input(0)->Get<framework::Tensor>(); auto input = context.Input(0)->Get<framework::Tensor>();
auto* output = context.Output(0)->GetMutable<framework::Tensor>(); auto* output = context.Output(0)->GetMutable<framework::Tensor>();
output->mutable_data<T>(context.GetPlace());
auto logits = framework::EigenMatrix<T>::From(input); auto logits = framework::EigenMatrix<T>::From(input);
auto softmax = framework::EigenMatrix<T>::From(*output); auto softmax = framework::EigenMatrix<T>::From(*output);
...@@ -41,19 +42,21 @@ public: ...@@ -41,19 +42,21 @@ public:
Eigen::DSizes<int, 2> batch_by_one(batch_size, 1); Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
Eigen::DSizes<int, 2> one_by_class(1, num_classes); Eigen::DSizes<int, 2> one_by_class(1, num_classes);
auto shifted_logits = (logits - logits.maximum(along_class) auto shifted_logits = (logits -
.eval() logits.maximum(along_class)
.reshape(batch_by_one) .eval()
.broadcast(one_by_class)); .reshape(batch_by_one)
.broadcast(one_by_class));
softmax.device(*(context.GetEigenDevice<Place>())) = shifted_logits.exp(); softmax.device(*(context.GetEigenDevice<Place>())) = shifted_logits.exp();
softmax.device(*(context.GetEigenDevice<Place>())) = softmax.device(*(context.GetEigenDevice<Place>())) =
(softmax * softmax.sum(along_class) (softmax *
.inverse() softmax.sum(along_class)
.eval() .inverse()
.reshape(batch_by_one) .eval()
.broadcast(one_by_class)); .reshape(batch_by_one)
.broadcast(one_by_class));
} }
}; };
} // namespace operators } // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册