Unverified commit 0285a2b9, authored by Yu Yang, committed by GitHub

Merge pull request #10371 from reyoung/refine_code

Polish MatMul, clean copy & paste code
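In short, call sites that previously used the free functions math::matmul, math::gemv, and math::axpy from math_function.h now obtain a device- and type-bound BLAS wrapper via math::GetBlas and call its member functions; the short MatMul overloads default alpha to 1 and beta to 0. A minimal before/after sketch of the pattern (A, B, and C are placeholder framework::Tensor objects, not names taken from this patch):

// Before: free function, device context passed explicitly.
math::matmul<DeviceContext, T>(dev_ctx, A, false, B, false,
                               static_cast<T>(1.0), &C, static_cast<T>(0.0));

// After: a Blas handle bound to the device context and element type.
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
blas.MatMul(A, false, B, false, &C);  // alpha = 1, beta = 0
blas.MatMul(A, B, &C);                // same, with trans_a = trans_b = false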
@@ -16,7 +16,7 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/operators/math/blas.h"
namespace paddle {
namespace operators {
...
@@ -17,9 +17,9 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/depthwise_conv.h"
#include "paddle/fluid/operators/math/im2col.h"
-#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/vol2col.h"
namespace paddle {
@@ -161,6 +161,7 @@ class GemmConvKernel : public framework::OpKernel<T> {
    math::Im2ColFunctor<math::ColFormat::kCFO, DeviceContext, T> im2col;
    auto& dev_ctx = context.template device_context<DeviceContext>();
+   auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
    for (int i = 0; i < batch_size; i++) {
      Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
      Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
@@ -186,8 +187,7 @@ class GemmConvKernel : public framework::OpKernel<T> {
        // gemm
        Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
        Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
-       math::matmul<DeviceContext, T>(dev_ctx, filter_slice, false, col_matrix,
-                                      false, T(1.0), &out_slice, T(0.0));
+       blas.MatMul(filter_slice, col_matrix, &out_slice);
      }
    }
  }
@@ -274,6 +274,7 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
    math::SetConstant<DeviceContext, T> set_zero;
    auto& dev_ctx = context.template device_context<DeviceContext>();
+   auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
    if (input_grad) {
      input_grad->mutable_data<T>(context.GetPlace());
@@ -303,9 +304,7 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
          col_matrix.ShareDataWith(in_grad_slice);
          col_matrix.Resize(col_matrix_shape);
        }
-       math::matmul<DeviceContext, T>(dev_ctx, filter_slice, true,
-                                      out_grad_slice, false, T(1.0),
-                                      &col_matrix, T(0.0));
+       blas.MatMul(filter_slice, true, out_grad_slice, false, &col_matrix);
        if (is_expand && data_dim == 2U) {
          col2im(dev_ctx, col, dilations, strides,
@@ -352,9 +351,8 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
        // gemm
        Tensor filter_grad_slice =
            filter_grad_.Slice(g * out_step, (g + 1) * out_step);
-       math::matmul<DeviceContext, T>(dev_ctx, out_grad_slice, false,
-                                      col_matrix, true, T(1.0),
-                                      &filter_grad_slice, T(1.0));
+       blas.MatMul(out_grad_slice, false, col_matrix, true,
+                   &filter_grad_slice);
      }
    }
  }
...
@@ -16,8 +16,8 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/im2col.h"
-#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/vol2col.h"
namespace paddle {
@@ -118,6 +118,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
    output->mutable_data<T>(context.GetPlace());
    math::SetConstant<DeviceContext, T> set_zero;
    auto& dev_ctx = context.template device_context<DeviceContext>();
+   auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
    set_zero(dev_ctx, output, static_cast<T>(0));
    math::Col2ImFunctor<math::ColFormat::kCFO, DeviceContext, T> col2im;
@@ -134,9 +135,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
      // col_matrix = filter * input_batch
      // of shape (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w)
-     math::matmul<DeviceContext, T>(dev_ctx, filter, true, input_batch, false,
-                                    static_cast<T>(1.0), &col_matrix,
-                                    static_cast<T>(0.0));
+     blas.MatMul(filter, true, input_batch, false, &col_matrix);
      if (data_dim == 2U) {
        // col2im: col_matrix -> dy
@@ -213,6 +212,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
    // im2col + gemm (similar to conv-forward)
    // input need to compute gradient
    auto& dev_ctx = context.template device_context<DeviceContext>();
+   auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
    if (input_grad || filter_grad) {
      Tensor col;
      col.mutable_data<T>(col_shape, context.GetPlace());
@@ -267,9 +267,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
          // or
          // (m, c * k_d * k_h * k_w) * (c * k_d * k_h * k_w, d * h * w) -> (m,
          // d, h, w)
-         math::matmul<DeviceContext, T>(
-             dev_ctx, filter, false, col_matrix, false, static_cast<T>(1.0),
-             &input_grad_batch, static_cast<T>(0.0));
+         blas.MatMul(filter, false, col_matrix, false, &input_grad_batch);
        }
        if (filter_grad) {
          // input batch
@@ -279,9 +277,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
          // or
          // (m, d * h * w) * (d * h * w, c * k_d * k_h * k_w) -> (m, c * k_d *
          // k_h * k_w)
-         math::matmul<DeviceContext, T>(dev_ctx, in_batch, false, col_matrix,
-                                        true, static_cast<T>(1.0),
-                                        &filter_grad_, static_cast<T>(1.0));
+         blas.MatMul(in_batch, false, col_matrix, true, &filter_grad_);
        }
      }
    }
...
@@ -14,11 +14,10 @@ limitations under the License. */
#pragma once
-#include "paddle/fluid/operators/activation_op.h"
-#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/activation_op.h"
+#include "paddle/fluid/operators/math/blas.h"
namespace paddle {
namespace operators {
...
@@ -15,8 +15,8 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/elementwise_op_function.h"
+#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
@@ -46,9 +46,9 @@ class RowwiseMean2D<platform::CUDADeviceContext, T> {
  }
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, framework::Tensor* out) {
-   math::gemv<platform::CUDADeviceContext, T>(
-       context, false, left_, right_, 1., input.data<T>(), divisor_.data<T>(),
-       0., out->data<T>());
+   math::GetBlas<platform::CUDADeviceContext, T>(context).GEMV(
+       false, left_, right_, 1., input.data<T>(), divisor_.data<T>(), 0.,
+       out->data<T>());
  }
 private:
@@ -93,9 +93,9 @@ class ColwiseSum2D<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, framework::Tensor* out) {
-   math::gemv<platform::CUDADeviceContext, T>(
-       context, true, left_, right_, 1., input.data<T>(), divisor_.data<T>(),
-       0., out->data<T>());
+   math::GetBlas<platform::CUDADeviceContext, T>(context).GEMV(
+       true, left_, right_, 1., input.data<T>(), divisor_.data<T>(), 0.,
+       out->data<T>());
  }
 private:
...
@@ -15,9 +15,9 @@ limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/fluid/operators/math/lstm_compute.h"
-#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/sequence2batch.h"
namespace paddle {
@@ -114,6 +114,7 @@ class LSTMKernel : public framework::OpKernel<T> {
    auto cand_act = math::detail::GetActivationType(
        ctx.Attr<std::string>("candidate_activation"));
+   auto blas = math::GetBlas<DeviceContext, T>(device_ctx);
    for (size_t n = 0; n < num_batch; n++) {
      int bstart = static_cast<int>(batch_starts[n]);
      int bend = static_cast<int>(batch_starts[n + 1]);
@@ -129,9 +130,8 @@ class LSTMKernel : public framework::OpKernel<T> {
        int pre_h_start = static_cast<int>(batch_starts[n - 1]);
        int pre_h_end = pre_h_start + cur_batch_size;
        auto pre_hidden_t = batch_hidden.Slice(pre_h_start, pre_h_end);
-       math::matmul<DeviceContext, T>(device_ctx, pre_hidden_t, false, *weight,
-                                      false, static_cast<T>(1.0), &gate_t,
-                                      static_cast<T>(1.0));
+       blas.MatMul(pre_hidden_t, false, *weight, false, static_cast<T>(1.0),
+                   &gate_t, static_cast<T>(1.0));
      } else if (hidden_t0) {
        // If n == 0 and there is no initialized hidden state, that is to say
        // the H0 is zeros, the calculation W_h * H0 will be skiped.
@@ -143,9 +143,8 @@ class LSTMKernel : public framework::OpKernel<T> {
        Tensor ordered_h0;
        ReorderInitState<DeviceContext, T>(device_ctx, *hidden_t0, order,
                                           &ordered_h0, true);
-       math::matmul<DeviceContext, T>(device_ctx, ordered_h0, false, *weight,
-                                      false, static_cast<T>(1.0), &gate_t,
-                                      static_cast<T>(1.0));
+       blas.MatMul(ordered_h0, false, *weight, false, static_cast<T>(1.0),
+                   &gate_t, static_cast<T>(1.0));
      }
      lstm_value.gate_value = gate_t.data<T>();
@@ -282,6 +281,7 @@ class LSTMGradKernel : public framework::OpKernel<T> {
    auto batch_starts = batch_gate->lod()[0];
    size_t num_batch = batch_starts.size() - 1;
+   auto blas = math::GetBlas<DeviceContext, T>(device_ctx);
    for (int n = static_cast<int>(num_batch) - 1; n >= 0; n--) {
      int bstart = static_cast<int>(batch_starts[n]);
      int bend = static_cast<int>(batch_starts[n + 1]);
@@ -320,29 +320,25 @@ class LSTMGradKernel : public framework::OpKernel<T> {
        int pre_h_start = static_cast<int>(batch_starts[n - 1]);
        int pre_h_end = pre_h_start + cur_batch_size;
        auto pre_hidden_g = batch_hidden_g.Slice(pre_h_start, pre_h_end);
-       math::matmul<DeviceContext, T>(device_ctx, gate_g, false, *weight, true,
-                                      static_cast<T>(1.0), &pre_hidden_g,
-                                      static_cast<T>(1.0));
+       blas.MatMul(gate_g, false, *weight, true, static_cast<T>(1.0),
+                   &pre_hidden_g, static_cast<T>(1.0));
        if (weight_g) {
          /* backward weight */
          auto pre_hidden = batch_hidden.Slice(pre_h_start, pre_h_end);
-         math::matmul<DeviceContext, T>(device_ctx, pre_hidden, true, gate_g,
-                                        false, static_cast<T>(1.0), weight_g,
-                                        static_cast<T>(1.0));
+         blas.MatMul(pre_hidden, true, gate_g, false, static_cast<T>(1.0),
+                     weight_g, static_cast<T>(1.0));
        }
      } else {
        if (h0 && weight_g) {
          ReorderInitState<DeviceContext, T>(device_ctx, *h0, order,
                                             &ordered_h0, true);
-         math::matmul<DeviceContext, T>(device_ctx, ordered_h0, true, gate_g,
-                                        false, static_cast<T>(1.0), weight_g,
-                                        static_cast<T>(1.0));
+         blas.MatMul(ordered_h0, true, gate_g, false, static_cast<T>(1.0),
+                     weight_g, static_cast<T>(1.0));
        }
        if (h0 && h0_g) {
          ordered_h0_g.mutable_data<T>(h0_g->dims(), ctx.GetPlace());
-         math::matmul<DeviceContext, T>(device_ctx, gate_g, false, *weight,
-                                        true, static_cast<T>(1.0),
-                                        &ordered_h0_g, static_cast<T>(0.0));
+         blas.MatMul(gate_g, false, *weight, true, static_cast<T>(1.0),
+                     &ordered_h0_g, static_cast<T>(0.0));
        }
      }
    }
...
@@ -14,15 +14,14 @@ limitations under the License. */
#pragma once
#include <string>
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/activation_op.h"
+#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/fluid/operators/math/lstm_compute.h"
-#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/sequence2batch.h"
-#include "paddle/fluid/framework/eigen.h"
-#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
@@ -143,7 +142,7 @@ class LSTMPKernel : public framework::OpKernel<T> {
    auto proj_act = math::detail::GetActivationType(
        ctx.Attr<std::string>("proj_activation"));
    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
+   auto blas = math::GetBlas<DeviceContext, T>(device_ctx);
    for (size_t n = 0; n < num_batch; n++) {
      int bstart = static_cast<int>(batch_starts[n]);
      int bend = static_cast<int>(batch_starts[n + 1]);
@@ -160,9 +159,8 @@ class LSTMPKernel : public framework::OpKernel<T> {
        int pre_h_start = static_cast<int>(batch_starts[n - 1]);
        int pre_h_end = pre_h_start + cur_batch_size;
        auto pre_proj_t = batch_proj.Slice(pre_h_start, pre_h_end);
-       math::matmul<DeviceContext, T>(device_ctx, pre_proj_t, false, *weight,
-                                      false, static_cast<T>(1.0), &gate_t,
-                                      static_cast<T>(1.0));
+       blas.MatMul(pre_proj_t, false, *weight, false, static_cast<T>(1.0),
+                   &gate_t, static_cast<T>(1.0));
      } else if (hidden_t0) {
        // If n == 0 and there is no initialized hidden state, that is to say
        // the H0 is zeros, the calculation W_h * H0 will be skiped.
@@ -176,16 +174,14 @@ class LSTMPKernel : public framework::OpKernel<T> {
        ordered_proj0->mutable_data<T>(ctx.GetPlace());
        ReorderInitState<DeviceContext, T>(device_ctx, *hidden_t0, order,
                                           &ordered_h0, true);
-       math::matmul<DeviceContext, T>(device_ctx, ordered_h0, false,
-                                      *proj_weight, false, static_cast<T>(1.0),
-                                      ordered_proj0, static_cast<T>(0.0));
+       blas.MatMul(ordered_h0, false, *proj_weight, false, static_cast<T>(1.0),
+                   ordered_proj0, static_cast<T>(0.0));
        if (proj_act != math::detail::ActivationType::kIdentity) {
          auto proj0_dev = EigenMatrix<T>::From(*ordered_proj0);
          ActCompute(cell_act, place, proj0_dev, proj0_dev);
        }
-       math::matmul<DeviceContext, T>(device_ctx, *ordered_proj0, false,
-                                      *weight, false, static_cast<T>(1.0),
-                                      &gate_t, static_cast<T>(1.0));
+       blas.MatMul(*ordered_proj0, false, *weight, false, static_cast<T>(1.0),
+                   &gate_t, static_cast<T>(1.0));
      }
      lstmp_value.gate_value = gate_t.data<T>();
@@ -196,9 +192,8 @@ class LSTMPKernel : public framework::OpKernel<T> {
          device_ctx, lstmp_value, frame_size, cur_batch_size, gate_act,
          cell_act, cand_act);
      lstmp_value.prev_state_value = lstmp_value.state_value;
-     math::matmul<DeviceContext, T>(device_ctx, hidden_t, false, *proj_weight,
-                                    false, static_cast<T>(1.0), &proj_t,
-                                    static_cast<T>(0.0));
+     blas.MatMul(hidden_t, false, *proj_weight, false, static_cast<T>(1.0),
+                 &proj_t, static_cast<T>(0.0));
      if (proj_act != math::detail::ActivationType::kIdentity) {
        auto proj_t_dev = EigenMatrix<T>::From(proj_t);
        ActCompute(cell_act, place, proj_t_dev, proj_t_dev);
@@ -361,6 +356,7 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
    auto batch_starts = batch_gate->lod()[0];
    size_t num_batch = batch_starts.size() - 1;
+   auto blas = math::GetBlas<DeviceContext, T>(device_ctx);
    for (int n = static_cast<int>(num_batch) - 1; n >= 0; n--) {
      int bstart = static_cast<int>(batch_starts[n]);
      int bend = static_cast<int>(batch_starts[n + 1]);
@@ -375,15 +371,13 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
      }
      /* hidden state backwarad */
      Tensor out_g = batch_hidden_g.Slice(bstart, bend);
-     math::matmul<DeviceContext, T>(device_ctx, proj_g, false, *proj_weight,
-                                    true, static_cast<T>(1.0), &out_g,
-                                    static_cast<T>(0.0));
+     blas.MatMul(proj_g, false, *proj_weight, true, static_cast<T>(1.0),
+                 &out_g, static_cast<T>(0.0));
      /* projection weight backward*/
      if (proj_weight_g) {
        Tensor hidden_t = batch_hidden->Slice(bstart, bend);
-       math::matmul<DeviceContext, T>(device_ctx, hidden_t, true, proj_g,
-                                      false, static_cast<T>(1.0),
-                                      proj_weight_g, static_cast<T>(1.0));
+       blas.MatMul(hidden_t, true, proj_g, false, static_cast<T>(1.0),
+                   proj_weight_g, static_cast<T>(1.0));
      }
      Tensor gate = batch_gate->Slice(bstart, bend);
@@ -419,24 +413,21 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
        int pre_h_start = static_cast<int>(batch_starts[n - 1]);
        int pre_h_end = pre_h_start + cur_batch_size;
        auto pre_proj_g = batch_proj_g.Slice(pre_h_start, pre_h_end);
-       math::matmul<DeviceContext, T>(device_ctx, gate_g, false, *weight, true,
-                                      static_cast<T>(1.0), &pre_proj_g,
-                                      static_cast<T>(1.0));
+       blas.MatMul(gate_g, false, *weight, true, static_cast<T>(1.0),
+                   &pre_proj_g, static_cast<T>(1.0));
        if (weight_g) {
          /* weight backward*/
          auto pre_proj = batch_proj.Slice(pre_h_start, pre_h_end);
-         math::matmul<DeviceContext, T>(device_ctx, pre_proj, true, gate_g,
-                                        false, static_cast<T>(1.0), weight_g,
-                                        static_cast<T>(1.0));
+         blas.MatMul(pre_proj, true, gate_g, false, static_cast<T>(1.0),
+                     weight_g, static_cast<T>(1.0));
        }
      } else {
        if (h0 && weight_g) {
          ReorderInitState<DeviceContext, T>(device_ctx, *h0, order,
                                             &ordered_h0, true);
          if (weight_g) {
-           math::matmul<DeviceContext, T>(device_ctx, *ordered_proj0, true,
-                                          gate_g, false, static_cast<T>(1.0),
-                                          weight_g, static_cast<T>(1.0));
+           blas.MatMul(*ordered_proj0, true, gate_g, false,
+                       static_cast<T>(1.0), weight_g, static_cast<T>(1.0));
          }
        }
        if (h0 && (h0_g || proj_weight_g)) {
@@ -444,9 +435,8 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
          Tensor proj0_g;
          proj0_g.Resize({in_dims[0], proj_weight->dims()[1]});
          proj0_g.mutable_data<T>(ctx.GetPlace());
-         math::matmul<DeviceContext, T>(device_ctx, gate_g, false, *weight,
-                                        true, static_cast<T>(1.0), &proj0_g,
-                                        static_cast<T>(0.0));
+         blas.MatMul(gate_g, false, *weight, true, static_cast<T>(1.0),
+                     &proj0_g, static_cast<T>(0.0));
          if (proj_act != math::detail::ActivationType::kIdentity) {
            auto proj0_dev = EigenMatrix<T>::From(*ordered_proj0);
            auto proj0_g_dev = EigenMatrix<T>::From(proj0_g);
@@ -454,14 +444,12 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
                       proj0_g_dev);
          }
          if (h0_g) {
-           math::matmul<DeviceContext, T>(
-               device_ctx, proj0_g, false, *proj_weight, true,
-               static_cast<T>(1.0), &ordered_h0_g, static_cast<T>(0.0));
+           blas.MatMul(proj0_g, false, *proj_weight, true, static_cast<T>(1.0),
+                       &ordered_h0_g, static_cast<T>(0.0));
          }
          if (proj_weight_g) {
-           math::matmul<DeviceContext, T>(device_ctx, ordered_h0, true,
-                                          proj0_g, false, static_cast<T>(1.0),
-                                          proj_weight_g, static_cast<T>(1.0));
+           blas.MatMul(ordered_h0, true, proj0_g, false, static_cast<T>(1.0),
+                       proj_weight_g, static_cast<T>(1.0));
          }
        }
      }
...
@@ -41,7 +41,8 @@ math_library(depthwise_conv)
math_library(gru_compute DEPS activation_functions math_function)
math_library(im2col)
math_library(lstm_compute DEPS activation_functions)
-math_library(math_function DEPS cblas)
+cc_library(blas SRCS blas.cc DEPS cblas framework_proto)
+math_library(math_function DEPS blas)
math_library(maxouting)
math_library(pooling)
math_library(selected_rows_functor DEPS selected_rows math_function)
...
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/math/blas.h"
namespace paddle {
namespace operators {
namespace math {
// Do nothing. Blas is a header only library.
} // namespace math
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#ifdef PADDLE_WITH_MKLML
#include <mkl_cblas.h>
#include <mkl_lapacke.h>
#include <mkl_vml_functions.h>
#endif
#ifdef PADDLE_USE_OPENBLAS
#include <cblas.h>
#include <lapacke.h>
#endif
#ifndef LAPACK_FOUND
extern "C" {
#include <cblas.h> // NOLINT
int LAPACKE_sgetrf(int matrix_layout, int m, int n, float* a, int lda,
int* ipiv);
int LAPACKE_dgetrf(int matrix_layout, int m, int n, double* a, int lda,
int* ipiv);
int LAPACKE_sgetri(int matrix_layout, int n, float* a, int lda,
const int* ipiv);
int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda,
const int* ipiv);
}
#endif
namespace paddle {
namespace operators {
namespace math {
template <typename DeviceContext>
class Blas {
public:
explicit Blas(const DeviceContext& context) : context_(context) {}
template <typename T>
void GEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
T alpha, const T* A, const T* B, T beta, T* C) const;
template <typename T>
void GEMM(bool transA, bool transB, int M, int N, int K, T alpha, const T* A,
int lda, const T* B, int ldb, T beta, T* C, int ldc) const;
template <typename T>
void MatMul(const framework::Tensor& mat_a, bool trans_a,
const framework::Tensor& mat_b, bool trans_b, T alpha,
framework::Tensor* mat_out, T beta) const;
template <typename T>
void MatMul(const framework::Tensor& mat_a, bool trans_a,
const framework::Tensor& mat_b, bool trans_b,
framework::Tensor* mat_out) const {
MatMul(mat_a, trans_a, mat_b, trans_b, static_cast<T>(1.0), mat_out,
static_cast<T>(0.0));
}
template <typename T>
void MatMul(const framework::Tensor& mat_a, const framework::Tensor& mat_b,
framework::Tensor* mat_out) const {
this->template MatMul<T>(mat_a, false, mat_b, false, mat_out);
}
template <typename T>
void AXPY(int n, T alpha, const T* x, T* y) const;
template <typename T>
void GEMV(bool trans_a, int M, int N, T alpha, const T* A, const T* B, T beta,
T* C) const;
template <typename T>
void BatchedGEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N,
int K, T alpha, const T* A, const T* B, T beta, T* C,
int batchCount, int64_t strideA, int64_t strideB) const;
private:
const DeviceContext& context_;
};
template <typename DeviceContext, typename T>
class BlasT : private Blas<DeviceContext> {
public:
using Blas<DeviceContext>::Blas;
template <typename... ARGS>
void GEMM(ARGS... args) const {
Base()->template GEMM<T>(args...);
}
template <typename... ARGS>
void MatMul(ARGS... args) const {
Base()->template MatMul<T>(args...);
}
template <typename... ARGS>
void AXPY(ARGS... args) const {
Base()->template AXPY<T>(args...);
}
template <typename... ARGS>
void GEMV(ARGS... args) const {
Base()->template GEMV<T>(args...);
}
template <typename... ARGS>
void BatchedGEMM(ARGS... args) const {
Base()->template BatchedGEMM<T>(args...);
}
private:
const Blas<DeviceContext>* Base() const {
return static_cast<const Blas<DeviceContext>*>(this);
}
};
template <typename DeviceContext, typename T>
inline BlasT<DeviceContext, T> GetBlas(
const framework::ExecutionContext& exe_ctx) {
return BlasT<DeviceContext, T>(
exe_ctx.template device_context<DeviceContext>());
}
template <typename DeviceContext, typename T>
inline BlasT<DeviceContext, T> GetBlas(const DeviceContext& dev_ctx) {
return BlasT<DeviceContext, T>(dev_ctx);
}
} // namespace math
} // namespace operators
} // namespace paddle
#include "paddle/fluid/operators/math/blas_impl.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/operators/math/blas_impl.cu.h"
#endif
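For reference, a minimal sketch of how the wrapper declared above is meant to be called from a CPU kernel; the pointer and size names (cpu_ctx, a_ptr, b_ptr, c_ptr, x_ptr, y_ptr, M, N, K, n) are illustrative and not part of the patch:

// BlasT<DeviceContext, T> fixes the element type, so call sites need no
// explicit template arguments.
auto blas = math::GetBlas<platform::CPUDeviceContext, float>(cpu_ctx);
// C = A * B with row-major M x K and K x N operands.
blas.GEMM(CblasNoTrans, CblasNoTrans, M, N, K, 1.0f, a_ptr, b_ptr, 0.0f, c_ptr);
// y = alpha * x + y
blas.AXPY(n, 2.0f, x_ptr, y_ptr);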
@@ -30,6 +30,25 @@ struct CUBlas<float> {
  static void GEMM(ARGS... args) {
    PADDLE_ENFORCE(platform::dynload::cublasSgemm(args...));
  }
template <typename... ARGS>
static void AXPY(ARGS... args) {
PADDLE_ENFORCE(platform::dynload::cublasSaxpy(args...));
}
template <typename... ARGS>
static void GEMV(ARGS... args) {
PADDLE_ENFORCE(platform::dynload::cublasSgemv(args...));
}
template <typename... ARGS>
static void GEMM_BATCH(ARGS... args) {
#if CUDA_VERSION >= 8000
PADDLE_ENFORCE(platform::dynload::cublasSgemmStridedBatched(args...));
#else
PADDLE_THROW("SgemmStridedBatched is not supported on cuda <= 7.5");
#endif
}
};
template <>
@@ -38,6 +57,25 @@ struct CUBlas<double> {
  static void GEMM(ARGS... args) {
    PADDLE_ENFORCE(platform::dynload::cublasDgemm(args...));
  }
template <typename... ARGS>
static void AXPY(ARGS... args) {
PADDLE_ENFORCE(platform::dynload::cublasDaxpy(args...));
}
template <typename... ARGS>
static void GEMV(ARGS... args) {
PADDLE_ENFORCE(platform::dynload::cublasDgemv(args...));
}
template <typename... ARGS>
static void GEMM_BATCH(ARGS... args) {
#if CUDA_VERSION >= 8000
PADDLE_ENFORCE(platform::dynload::cublasDgemmStridedBatched(args...));
#else
PADDLE_THROW("DgemmStridedBatched is not supported on cuda <= 7.5");
#endif
}
};
template <>
@@ -57,16 +95,23 @@ struct CUBlas<platform::float16> {
                                      reinterpret_cast<const __half *>(beta),
                                      reinterpret_cast<__half *>(C), ldc));
  }
template <typename... ARGS>
static void GEMM_BATCH(ARGS... args) {
#if CUDA_VERSION >= 8000
PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched(args...));
#else
PADDLE_THROW("HgemmStridedBatched is not supported on cuda <= 7.5");
#endif
}
};
template <>
template <typename T>
-void Blas<platform::CUDADeviceContext>::GEMM(const CBLAS_TRANSPOSE transA,
-                                             const CBLAS_TRANSPOSE transB,
-                                             const int M, const int N,
-                                             const int K, const T alpha,
-                                             const T *A, const T *B,
-                                             const T beta, T *C) const {
+void Blas<platform::CUDADeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
+                                             CBLAS_TRANSPOSE transB, int M,
+                                             int N, int K, T alpha, const T *A,
+                                             const T *B, T beta, T *C) const {
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  int lda = (transA == CblasNoTrans) ? K : M;
@@ -83,10 +128,10 @@ void Blas<platform::CUDADeviceContext>::GEMM(const CBLAS_TRANSPOSE transA,
template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMM(
-    const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M,
-    const int N, const int K, const platform::float16 alpha,
-    const platform::float16 *A, const platform::float16 *B,
-    const platform::float16 beta, platform::float16 *C) const {
+    CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
+    platform::float16 alpha, const platform::float16 *A,
+    const platform::float16 *B, platform::float16 beta,
+    platform::float16 *C) const {
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  int lda = (transA == CblasNoTrans) ? K : M;
@@ -134,18 +179,58 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
template <>
template <typename T>
-void Blas<platform::CUDADeviceContext>::GEMM(
-    const bool transA, const bool transB, const int M, const int N, const int K,
-    const T alpha, const T *A, const int lda, const T *B, const int ldb,
-    const T beta, T *C, const int ldc) const {
+void Blas<platform::CUDADeviceContext>::GEMM(bool transA, bool transB, int M,
+                                             int N, int K, T alpha, const T *A,
+                                             int lda, const T *B, int ldb,
+                                             T beta, T *C, int ldc) const {
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
-  cublasOperation_t cuTransA = transA == false ? CUBLAS_OP_N : CUBLAS_OP_T;
-  cublasOperation_t cuTransB = transB == false ? CUBLAS_OP_N : CUBLAS_OP_T;
+  cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
+  cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
  CUBlas<T>::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha,
                  B, ldb, A, lda, &beta, C, ldc);
}
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::AXPY(int n, T alpha, const T *x,
T *y) const {
CUBlas<T>::AXPY(context_.cublas_handle(), n, &alpha, x, 1, y, 1);
}
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::GEMV(bool trans_a, int M, int N,
T alpha, const T *A, const T *B,
T beta, T *C) const {
cublasOperation_t cuTransA = !trans_a ? CUBLAS_OP_T : CUBLAS_OP_N;
CUBlas<T>::GEMV(context_.cublas_handle(), cuTransA, N, M, &alpha, A, N, B, 1,
&beta, C, 1);
}
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::BatchedGEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
T alpha, const T *A, const T *B, T beta, T *C, int batchCount,
int64_t strideA, int64_t strideB) const {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
const int64_t strideC = M * N;
CUBlas<T>::GEMM_BATCH(context_.cublas_handle(), cuTransB, cuTransA, N, M, K,
&alpha, B, ldb, strideB, A, lda, strideA, &beta, C, ldc,
strideC, batchCount);
}
}  // namespace math
}  // namespace operators
}  // namespace paddle
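As a cross-reference, the RowwiseMean2D change earlier in this patch uses the GEMV wrapper above to compute per-row means by multiplying the input with a vector whose entries are 1/N. A hedged sketch with illustrative names (cuda_ctx, input_ptr, divisor_ptr, out_ptr, M, N):

// out[i] = sum_j input[i][j] * divisor[j], with divisor[j] == 1.0f / N
auto blas = math::GetBlas<platform::CUDADeviceContext, float>(cuda_ctx);
blas.GEMV(false, M, N, 1.0f, input_ptr, divisor_ptr, 0.0f, out_ptr);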
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
+#include <vector>
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
@@ -28,6 +28,23 @@ struct CBlas<float> {
  static void GEMM(ARGS... args) {
    cblas_sgemm(args...);
  }
template <typename... ARGS>
static void AXPY(ARGS... args) {
cblas_saxpy(args...);
}
template <typename... ARGS>
static void GEMV(ARGS... args) {
cblas_sgemv(args...);
}
#ifdef PADDLE_WITH_MKLML
template <typename... ARGS>
static void GEMM_BATCH(ARGS... args) {
cblas_sgemm_batch(args...);
}
#endif
};
template <>
@@ -36,21 +53,41 @@ struct CBlas<double> {
  static void GEMM(ARGS... args) {
    cblas_dgemm(args...);
  }
template <typename... ARGS>
static void AXPY(ARGS... args) {
cblas_daxpy(args...);
}
template <typename... ARGS>
static void GEMV(ARGS... args) {
cblas_dgemv(args...);
}
#ifdef PADDLE_WITH_MKLML
template <typename... ARGS>
static void GEMM_BATCH(ARGS... args) {
cblas_dgemm_batch(args...);
}
#endif
};
template <>
struct CBlas<platform::float16> {
  static void GEMM(...) { PADDLE_THROW("float16 GEMM not supported on CPU"); }
#ifdef PADDLE_WITH_MKLML
static void GEMM_BATCH(...) {
PADDLE_THROW("float16 GEMM_BATCH not supported on CPU");
}
#endif
};
template <>
template <typename T>
-void Blas<platform::CPUDeviceContext>::GEMM(const CBLAS_TRANSPOSE transA,
-                                            const CBLAS_TRANSPOSE transB,
-                                            const int M, const int N,
-                                            const int K, const T alpha,
-                                            const T *A, const T *B,
-                                            const T beta, T *C) const {
+void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
+                                            CBLAS_TRANSPOSE transB, int M,
+                                            int N, int K, T alpha, const T *A,
+                                            const T *B, T beta, T *C) const {
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  int ldc = N;
@@ -60,15 +97,89 @@ void Blas<platform::CPUDeviceContext>::GEMM(const CBLAS_TRANSPOSE transA,
template <>
template <typename T>
-void Blas<platform::CPUDeviceContext>::GEMM(
-    const bool transA, const bool transB, const int M, const int N, const int K,
-    const T alpha, const T *A, const int lda, const T *B, const int ldb,
-    const T beta, T *C, const int ldc) const {
+void Blas<platform::CPUDeviceContext>::GEMM(bool transA, bool transB, int M,
+                                            int N, int K, T alpha, const T *A,
+                                            int lda, const T *B, int ldb,
+                                            T beta, T *C, int ldc) const {
  CBlas<T>::GEMM(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans,
                 transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A,
                 lda, B, ldb, beta, C, ldc);
}
template <typename DeviceContext>
template <typename T>
void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a, bool trans_a,
const framework::Tensor &mat_b, bool trans_b,
T alpha, framework::Tensor *mat_out,
T beta) const {
auto dim_a = mat_a.dims();
auto dim_b = mat_b.dims();
auto dim_out = mat_out->dims();
PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
"The input and output of matmul be matrix");
PADDLE_ENFORCE(
mat_a.place() == mat_b.place() && mat_a.place() == mat_out->place(),
"The places of matrices must be same");
int M = dim_out[0];
int N = dim_out[1];
int K = !trans_a ? dim_a[1] : dim_a[0];
CBLAS_TRANSPOSE transA = !trans_a ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = !trans_b ? CblasNoTrans : CblasTrans;
this->GEMM(transA, transB, M, N, K, alpha, mat_a.data<T>(), mat_b.data<T>(),
beta, mat_out->data<T>());
}
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::AXPY(int n, T alpha, const T *x,
T *y) const {
CBlas<T>::AXPY(n, alpha, x, 1, y, 1);
}
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::GEMV(bool trans_a, int M, int N, T alpha,
const T *A, const T *B, T beta,
T *C) const {
CBLAS_TRANSPOSE transA = !trans_a ? CblasNoTrans : CblasTrans;
CBlas<T>::GEMV(CblasRowMajor, transA, M, N, alpha, A, N, B, 1, beta, C, 1);
}
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::BatchedGEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
T alpha, const T *A, const T *B, T beta, T *C, int batchCount,
int64_t strideA, int64_t strideB) const {
#ifdef PADDLE_WITH_MKLML
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
auto a_array = std::vector<const T *>(batchCount);
auto b_array = std::vector<const T *>(batchCount);
auto c_array = std::vector<T *>(batchCount);
for (int k = 0; k < batchCount; ++k) {
a_array[k] = &A[k * strideA];
b_array[k] = &B[k * strideB];
c_array[k] = &C[k * M * N];
}
CBlas<T>::GEMM_BATCH(CblasRowMajor, &transA, &transB, &M, &N, &K, &alpha,
a_array.data(), &lda, b_array.data(), &ldb, &beta,
c_array.data(), &ldc, 1 /* group_count */, &batchCount);
#else
for (int k = 0; k < batchCount; ++k) {
const float *Ak = &A[k * strideA];
const float *Bk = &B[k * strideB];
float *Ck = &C[k * M * N];
this->template GEMM<T>(transA, transB, M, N, K, alpha, Ak, Bk, beta, Ck);
}
#endif
}
}  // namespace math
}  // namespace operators
}  // namespace paddle
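The CPU BatchedGEMM above dispatches to cblas_{s,d}gemm_batch when MKL is available and otherwise falls back to a serial loop over GEMM. A usage sketch under assumed names (cpu_ctx, A_data, B_data, C_data, batch, M, N, K), with the per-batch matrices stored back to back in row-major order:

auto blas = math::GetBlas<platform::CPUDeviceContext, float>(cpu_ctx);
// C[k] = A[k] * B[k] for k in [0, batch), where A[k] is M x K and B[k] is
// K x N; strides give the element offset between consecutive matrices.
blas.BatchedGEMM(CblasNoTrans, CblasNoTrans, M, N, K, 1.0f, A_data, B_data,
                 0.0f, C_data, batch,
                 static_cast<int64_t>(M) * K, static_cast<int64_t>(K) * N);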
@@ -17,8 +17,8 @@ limitations under the License. */
#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/im2col.h"
-#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
@@ -211,6 +211,7 @@ class ContextProjectGradFunctor {
    int input_row_begin, input_row_end;
    int sequence_height, sequence_width;
    sequence_width = in.dims()[1];
+   auto blas = math::GetBlas<DeviceContext, T>(context);
    if (input_grad) {
      for (int i = 0; i < static_cast<int>(lod_level_0.size()) - 1; ++i) {
@@ -262,8 +263,8 @@ class ContextProjectGradFunctor {
              Tensor out_t_sub = out_t.Slice(k * context_length,
                                             k * context_length + padding_size);
              Tensor w_sub = padding_data->Slice(k, k + padding_size);
-             axpy<DeviceContext, T>(context, w_sub.numel(), static_cast<T>(1),
-                                    out_t_sub.data<T>(), w_sub.data<T>());
+             blas.AXPY(w_sub.numel(), static_cast<T>(1), out_t_sub.data<T>(),
+                       w_sub.data<T>());
            }
          }
          if (down_pad > 0) {
@@ -294,8 +295,8 @@ class ContextProjectGradFunctor {
                  (down_pad_begin_row + t) * context_length);
              Tensor w_sub = padding_data->Slice(
                  up_pad + padding_idx, up_pad + padding_idx + padding_size);
-             axpy<DeviceContext, T>(context, w_sub.numel(), static_cast<T>(1),
-                                    out_t_sub.data<T>(), w_sub.data<T>());
+             blas.AXPY(w_sub.numel(), static_cast<T>(1), out_t_sub.data<T>(),
+                       w_sub.data<T>());
            }
          }
          out_t.Resize({sequence_height, context_length * sequence_width});
...
@@ -10,9 +10,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/gru_compute.h"
+#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/detail/gru_cpu_kernel.h"
#include "paddle/fluid/operators/math/detail/gru_kernel.h"
-#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
...
@@ -10,10 +10,10 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <paddle/fluid/platform/device_context.h>
+#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/detail/gru_gpu_kernel.h"
#include "paddle/fluid/operators/math/detail/gru_kernel.h"
#include "paddle/fluid/operators/math/gru_compute.h"
-#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
...
@@ -24,200 +24,6 @@ namespace math {
using float16 = paddle::platform::float16;
template <>
void matmul<platform::CPUDeviceContext, float16>(
const platform::CPUDeviceContext& context,
const framework::Tensor& matrix_a, bool trans_a,
const framework::Tensor& matrix_b, bool trans_b, float16 alpha,
framework::Tensor* matrix_out, float16 beta) {
PADDLE_THROW("float16 matmul not supported on CPU");
}
template <>
void matmul<platform::CPUDeviceContext, float>(
const platform::CPUDeviceContext& context,
const framework::Tensor& matrix_a, bool trans_a,
const framework::Tensor& matrix_b, bool trans_b, float alpha,
framework::Tensor* matrix_out, float beta) {
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
"The input and output of matmul be matrix");
PADDLE_ENFORCE(platform::is_cpu_place(matrix_a.place()) &&
platform::is_cpu_place(matrix_b.place()) &&
platform::is_cpu_place(matrix_out->place()),
"Matrix must all be in CPUPlace");
int M = dim_out[0];
int N = dim_out[1];
int K = (trans_a == false) ? dim_a[1] : dim_a[0];
CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
Blas<platform::CPUDeviceContext>(context).GEMM(
transA, transB, M, N, K, alpha, matrix_a.data<float>(),
matrix_b.data<float>(), beta, matrix_out->data<float>());
}
template <>
void matmul<platform::CPUDeviceContext, double>(
const platform::CPUDeviceContext& context,
const framework::Tensor& matrix_a, bool trans_a,
const framework::Tensor& matrix_b, bool trans_b, double alpha,
framework::Tensor* matrix_out, double beta) {
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
"The input and output of matmul be matrix");
PADDLE_ENFORCE(platform::is_cpu_place(matrix_a.place()) &&
platform::is_cpu_place(matrix_b.place()) &&
platform::is_cpu_place(matrix_out->place()),
"Matrix must all be in CPUPlace");
int M = dim_out[0];
int N = dim_out[1];
int K = (trans_a == false) ? dim_a[1] : dim_a[0];
CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
Blas<platform::CPUDeviceContext>(context).GEMM(
transA, transB, M, N, K, alpha, matrix_a.data<double>(),
matrix_b.data<double>(), beta, matrix_out->data<double>());
}
template <>
void batched_gemm<platform::CPUDeviceContext, float16>(
const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
const float16 alpha, const float16* A, const float16* B, const float16 beta,
float16* C, const int batchCount, const int64_t strideA,
const int64_t strideB) {
PADDLE_THROW("float16 batched_gemm not supported on CPU");
}
#ifdef PADDLE_WITH_MKLML
// Use cblas_{s,d}gemm_batched if available: Run with 1 group of size batchSize.
template <>
void batched_gemm<platform::CPUDeviceContext, float>(
const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
const float alpha, const float* A, const float* B, const float beta,
float* C, const int batchCount, const int64_t strideA,
const int64_t strideB) {
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
auto a_array = std::vector<const float*>(batchCount);
auto b_array = std::vector<const float*>(batchCount);
auto c_array = std::vector<float*>(batchCount);
for (int k = 0; k < batchCount; ++k) {
a_array[k] = &A[k * strideA];
b_array[k] = &B[k * strideB];
c_array[k] = &C[k * M * N];
}
cblas_sgemm_batch(CblasRowMajor, &transA, &transB, &M, &N, &K, &alpha,
a_array.data(), &lda, b_array.data(), &ldb, &beta,
c_array.data(), &ldc, 1 /* group_count */, &batchCount);
}
template <>
void batched_gemm<platform::CPUDeviceContext, double>(
const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
const double alpha, const double* A, const double* B, const double beta,
double* C, const int batchCount, const int64_t strideA,
const int64_t strideB) {
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
auto a_array = std::vector<const double*>(batchCount);
auto b_array = std::vector<const double*>(batchCount);
auto c_array = std::vector<double*>(batchCount);
for (int k = 0; k < batchCount; ++k) {
a_array[k] = &A[k * strideA];
b_array[k] = &B[k * strideB];
c_array[k] = &C[k * M * N];
}
cblas_dgemm_batch(CblasRowMajor, &transA, &transB, &M, &N, &K, &alpha,
a_array.data(), &lda, b_array.data(), &ldb, &beta,
c_array.data(), &ldc, 1 /* group_count */, &batchCount);
}
#else
// The below is a naive but correct serial implementation that just loops
// over the batch dimension. This is a fallback for when the batched gemm
// functions of Intel MKL are not available. In the future, this computation
// should be parallelized.
template <>
void batched_gemm<platform::CPUDeviceContext, float>(
const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
const float alpha, const float* A, const float* B, const float beta,
float* C, const int batchCount, const int64_t strideA,
const int64_t strideB) {
for (int k = 0; k < batchCount; ++k) {
const float* Ak = &A[k * strideA];
const float* Bk = &B[k * strideB];
float* Ck = &C[k * M * N];
Blas<platform::CPUDeviceContext>(context).GEMM(transA, transB, M, N, K,
alpha, Ak, Bk, beta, Ck);
}
}
template <>
void batched_gemm<platform::CPUDeviceContext, double>(
const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
const double alpha, const double* A, const double* B, const double beta,
double* C, const int batchCount, const int64_t strideA,
const int64_t strideB) {
for (int k = 0; k < batchCount; ++k) {
const double* Ak = &A[k * strideA];
const double* Bk = &B[k * strideB];
double* Ck = &C[k * M * N];
Blas<platform::CPUDeviceContext>(context).GEMM(transA, transB, M, N, K,
alpha, Ak, Bk, beta, Ck);
}
}
#endif
template <>
void gemv<platform::CPUDeviceContext, float>(
const platform::CPUDeviceContext& context, const bool trans_a, const int M,
const int N, const float alpha, const float* A, const float* B,
const float beta, float* C) {
CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
cblas_sgemv(CblasRowMajor, transA, M, N, alpha, A, N, B, 1, beta, C, 1);
}
template <>
void gemv<platform::CPUDeviceContext, double>(
const platform::CPUDeviceContext& context, const bool trans_a, const int M,
const int N, const double alpha, const double* A, const double* B,
const double beta, double* C) {
CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
cblas_dgemv(CblasRowMajor, transA, M, N, alpha, A, N, B, 1, beta, C, 1);
}
template <>
void axpy<platform::CPUDeviceContext, float>(
const platform::CPUDeviceContext& context, const int n, const float alpha,
const float* x, float* y) {
cblas_saxpy(n, alpha, x, 1, y, 1);
}
template <>
void axpy<platform::CPUDeviceContext, double>(
const platform::CPUDeviceContext& context, const int n, const double alpha,
const double* x, double* y) {
cblas_daxpy(n, alpha, x, 1, y, 1);
}
template struct SetConstant<platform::CPUDeviceContext, platform::float16>;
template struct SetConstant<platform::CPUDeviceContext, float>;
template struct SetConstant<platform::CPUDeviceContext, double>;
......
...@@ -15,6 +15,7 @@ limitations under the License. */
#define EIGEN_USE_GPU
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/math_function_impl.h"
#include "paddle/fluid/platform/float16.h"
...@@ -25,223 +26,6 @@ namespace math {
using float16 = paddle::platform::float16;
template <>
void matmul<platform::CUDADeviceContext, float16>(
const platform::CUDADeviceContext& context,
const framework::Tensor& matrix_a, bool trans_a,
const framework::Tensor& matrix_b, bool trans_b, float16 alpha,
framework::Tensor* matrix_out, float16 beta) {
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
"The input and output of matmul be matrix");
PADDLE_ENFORCE(platform::is_gpu_place(matrix_a.place()) &&
platform::is_gpu_place(matrix_b.place()) &&
platform::is_gpu_place(matrix_out->place()),
"Matrix must all be in CUDAPlace");
int M = dim_out[0];
int N = dim_out[1];
int K = (trans_a == false) ? dim_a[1] : dim_a[0];
CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
Blas<platform::CUDADeviceContext>(context).GEMM(
transA, transB, M, N, K, alpha, matrix_a.data<float16>(),
matrix_b.data<float16>(), beta, matrix_out->data<float16>());
}
template <>
void matmul<platform::CUDADeviceContext, float>(
const platform::CUDADeviceContext& context,
const framework::Tensor& matrix_a, bool trans_a,
const framework::Tensor& matrix_b, bool trans_b, float alpha,
framework::Tensor* matrix_out, float beta) {
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
"The input and output of matmul be matrix");
PADDLE_ENFORCE(platform::is_gpu_place(matrix_a.place()) &&
platform::is_gpu_place(matrix_b.place()) &&
platform::is_gpu_place(matrix_out->place()),
"Matrix must all be in CUDAPlace");
int M = dim_out[0];
int N = dim_out[1];
int K = (trans_a == false) ? dim_a[1] : dim_a[0];
CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
Blas<platform::CUDADeviceContext>(context).GEMM(
transA, transB, M, N, K, alpha, matrix_a.data<float>(),
matrix_b.data<float>(), beta, matrix_out->data<float>());
}
template <>
void matmul<platform::CUDADeviceContext, double>(
const platform::CUDADeviceContext& context,
const framework::Tensor& matrix_a, bool trans_a,
const framework::Tensor& matrix_b, bool trans_b, double alpha,
framework::Tensor* matrix_out, double beta) {
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
"The input and output of matmul be matrix");
PADDLE_ENFORCE(platform::is_gpu_place(matrix_a.place()) &&
platform::is_gpu_place(matrix_b.place()) &&
platform::is_gpu_place(matrix_out->place()),
"Matrix must all be in CUDAPlace");
int M = dim_out[0];
int N = dim_out[1];
int K = (trans_a == false) ? dim_a[1] : dim_a[0];
CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
Blas<platform::CUDADeviceContext>(context).GEMM(
transA, transB, M, N, K, alpha, matrix_a.data<double>(),
matrix_b.data<double>(), beta, matrix_out->data<double>());
}
template <>
void batched_gemm<platform::CUDADeviceContext, float16>(
const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
const float16 alpha, const float16* A, const float16* B, const float16 beta,
float16* C, const int batchCount, const int64_t strideA,
const int64_t strideB) {
#if CUDA_VERSION >= 8000
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
const int64_t strideC = M * N;
const half h_alpha = static_cast<const half>(alpha);
const half h_beta = static_cast<const half>(beta);
const half* h_A = reinterpret_cast<const half*>(A);
const half* h_B = reinterpret_cast<const half*>(B);
half* h_C = reinterpret_cast<half*>(C);
// TODO(kexinzhao): add processing code for compute capability < 53 case
PADDLE_ENFORCE_GE(context.GetComputeCapability(), 53,
"cublas Hgemm requires GPU compute capability >= 53");
PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched(
context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb,
strideB, h_A, lda, strideA, &h_beta, h_C, ldc, strideC, batchCount));
#else
PADDLE_ENFORCE(false, "HgemmStridedBatched is not supported on cuda <= 7.5");
#endif
}
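A note for readers (added for clarity; not part of the diff): cuBLAS assumes column-major storage, while these kernels keep their buffers row-major. A row-major matrix reinterpreted as column-major is its transpose, so the row-major product is obtained by asking cuBLAS for the transposed product:

$$C = \mathrm{op}(A)\,\mathrm{op}(B) \quad\Longleftrightarrow\quad C^{\mathsf{T}} = \mathrm{op}(B)^{\mathsf{T}}\,\mathrm{op}(A)^{\mathsf{T}},$$

which is why this call and the float/double versions below pass cuTransB and cuTransA, the dimensions N, M, K, and the B pointer before the A pointer.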
template <>
void batched_gemm<platform::CUDADeviceContext, float>(
const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
const float alpha, const float* A, const float* B, const float beta,
float* C, const int batchCount, const int64_t strideA,
const int64_t strideB) {
#if CUDA_VERSION >= 8000
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
const int64_t strideC = M * N;
PADDLE_ENFORCE(platform::dynload::cublasSgemmStridedBatched(
context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb,
strideB, A, lda, strideA, &beta, C, ldc, strideC, batchCount));
#else
PADDLE_ENFORCE(false, "SgemmStridedBatched is not supported on cuda <= 7.5");
#endif
}
template <>
void batched_gemm<platform::CUDADeviceContext, double>(
const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
const double alpha, const double* A, const double* B, const double beta,
double* C, const int batchCount, const int64_t strideA,
const int64_t strideB) {
#if CUDA_VERSION >= 8000
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
const int64_t strideC = M * N;
PADDLE_ENFORCE(platform::dynload::cublasDgemmStridedBatched(
context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb,
strideB, A, lda, strideA, &beta, C, ldc, strideC, batchCount));
#else
PADDLE_ENFORCE(false, "DgemmStridedBatched is not supported on cuda <= 7.5");
#endif
}
template <>
void gemv<platform::CUDADeviceContext, float>(
const platform::CUDADeviceContext& context, const bool trans_a, const int M,
const int N, const float alpha, const float* A, const float* B,
const float beta, float* C) {
cublasOperation_t cuTransA = (trans_a == false) ? CUBLAS_OP_T : CUBLAS_OP_N;
PADDLE_ENFORCE(platform::dynload::cublasSgemv(context.cublas_handle(),
cuTransA, N, M, &alpha, A, N, B,
1, &beta, C, 1));
}
template <>
void gemv<platform::CUDADeviceContext, double>(
const platform::CUDADeviceContext& context, const bool trans_a, const int M,
const int N, const double alpha, const double* A, const double* B,
const double beta, double* C) {
cublasOperation_t cuTransA = (trans_a == false) ? CUBLAS_OP_T : CUBLAS_OP_N;
PADDLE_ENFORCE(platform::dynload::cublasDgemv(context.cublas_handle(),
cuTransA, N, M, &alpha, A, N, B,
1, &beta, C, 1));
}
template <>
void axpy<platform::CUDADeviceContext, float>(
const platform::CUDADeviceContext& context, const int n, const float alpha,
const float* x, float* y) {
PADDLE_ENFORCE(platform::dynload::cublasSaxpy(context.cublas_handle(), n,
&alpha, x, 1, y, 1));
}
template <>
void axpy<platform::CUDADeviceContext, double>(
const platform::CUDADeviceContext& context, const int n, const double alpha,
const double* x, double* y) {
PADDLE_ENFORCE(platform::dynload::cublasDaxpy(context.cublas_handle(), n,
&alpha, x, 1, y, 1));
}
template struct SetConstant<platform::CUDADeviceContext, platform::float16>;
template struct SetConstant<platform::CUDADeviceContext, float>;
template struct SetConstant<platform::CUDADeviceContext, double>;
...@@ -333,10 +117,9 @@ void ColwiseSum<platform::CUDADeviceContext, double>::operator()(
one.mutable_data<double>({in_dims[0]}, context.GetPlace());
SetConstant<platform::CUDADeviceContext, double> set;
set(context, &one, static_cast<double>(1.0));
gemv<platform::CUDADeviceContext, double>(
    context, true, static_cast<int>(in_dims[0]), static_cast<int>(in_dims[1]),
    1.0, input.data<double>(), one.data<double>(), 0.0,
    vector->data<double>());
GetBlas<platform::CUDADeviceContext, double>(context).GEMV(
    true, static_cast<int>(in_dims[0]), static_cast<int>(in_dims[1]), 1.0,
    input.data<double>(), one.data<double>(), 0.0, vector->data<double>());
}
template struct RowwiseSum<platform::CUDADeviceContext, float>;
...@@ -355,10 +138,9 @@ void RowwiseSum<platform::CUDADeviceContext, double>::operator()(
one.mutable_data<double>({size}, context.GetPlace());
SetConstant<platform::CUDADeviceContext, double> set;
set(context, &one, static_cast<double>(1.0));
gemv<platform::CUDADeviceContext, double>(
    context, true, static_cast<int>(in_dims[1]), static_cast<int>(in_dims[0]),
    1.0, one.data<double>(), input.data<double>(), 0.0,
    vector->data<double>());
GetBlas<platform::CUDADeviceContext, double>(context).GEMV(
    true, static_cast<int>(in_dims[1]), static_cast<int>(in_dims[0]), 1.0,
    one.data<double>(), input.data<double>(), 0.0, vector->data<double>());
}
template struct RowwiseMean<platform::CUDADeviceContext, float>;
......
...@@ -51,78 +51,6 @@ int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda,
namespace paddle {
namespace operators {
namespace math {
// Support continuous memory now
// If transA = N, and transB = N
// Then matrixA: M * K, matrixB: K * N, matrixC : M * N
// For more detailed info, please refer to
// http://www.netlib.org/lapack/explore-html/d4/de2/sgemm_8f.html
template <typename DeviceContext>
class Blas {
public:
explicit Blas(const DeviceContext& context) : context_(context) {}
template <typename T>
void GEMM(const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB,
const int M, const int N, const int K, const T alpha, const T* A,
const T* B, const T beta, T* C) const;
template <typename T>
void GEMM(const bool transA, const bool transB, const int M, const int N,
const int K, const T alpha, const T* A, const int lda, const T* B,
const int ldb, const T beta, T* C, const int ldc) const;
private:
const DeviceContext& context_;
};
template <typename DeviceContext, typename T>
class BlasT : private Blas<DeviceContext> {
public:
using Blas<DeviceContext>::Blas;
template <typename... ARGS>
void GEMM(ARGS... args) const {
static_cast<const Blas<DeviceContext>*>(this)->template GEMM<T>(args...);
}
};
template <typename DeviceContext, typename T>
inline BlasT<DeviceContext, T> GetBlas(
const framework::ExecutionContext& exe_ctx) {
return BlasT<DeviceContext, T>(
exe_ctx.template device_context<DeviceContext>());
}
template <typename DeviceContext, typename T>
inline BlasT<DeviceContext, T> GetBlas(const DeviceContext& dev_ctx) {
return BlasT<DeviceContext, T>(dev_ctx);
}
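To make the refactor easier to follow, here is a minimal usage sketch of the handle returned by GetBlas, based only on calls that appear elsewhere in this pull request; the names ctx, M, N, K, a_ptr, b_ptr, c_ptr, x_matrix, y_matrix and z are placeholders for whatever the calling kernel has, not identifiers introduced by the patch.
// Sketch of how kernels use the Blas handle after this change (float, CPU).
auto& dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
auto blas = math::GetBlas<platform::CPUDeviceContext, float>(dev_ctx);
// Raw-pointer GEMM, as in MatMulFunctor: C = alpha * op(A) * op(B) + beta * C.
blas.GEMM(CblasNoTrans, CblasNoTrans, M, N, K, 1.0f, a_ptr, b_ptr, 0.0f, c_ptr);
// Tensor-level multiply, as in MulKernel: z = x_matrix * y_matrix.
blas.MatMul(x_matrix, y_matrix, z);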
// matrix multiply with continuous memory
template <typename DeviceContext, typename T>
void matmul(const DeviceContext& context, const framework::Tensor& matrix_a,
bool trans_a, const framework::Tensor& matrix_b, bool trans_b,
T alpha, framework::Tensor* matrix_out, T beta);
// Batched gemm
template <typename DeviceContext, typename T>
void batched_gemm(const DeviceContext& context, const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M, const int N,
const int K, const T alpha, const T* A, const T* B,
const T beta, T* C, const int batchCount,
const int64_t strideA, const int64_t strideB);
template <typename DeviceContext, typename T>
void gemv(const DeviceContext& context, const bool trans_a, const int M,
const int N, const T alpha, const T* A, const T* B, const T beta,
T* C);
template <typename DeviceContext, typename T>
void axpy(const DeviceContext& context, const int n, const T alpha, const T* x,
T* y);
template <typename DeviceContext, typename T, int Rank>
struct Transpose {
void operator()(const DeviceContext& context, const framework::Tensor& in,
...@@ -169,8 +97,3 @@ struct RowwiseMean {
} // namespace math
} // namespace operators
} // namespace paddle
#include "paddle/fluid/operators/math/blas_impl.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/operators/math/blas_impl.cu.h"
#endif
...@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/operators/math/math_function.h"
#include "gtest/gtest.h"
#include "paddle/fluid/operators/math/blas.h"
template <typename T>
inline paddle::operators::math::BlasT<paddle::platform::CPUDeviceContext, T>
...@@ -129,9 +130,8 @@ void GemvTest(int m, int n, bool trans) {
}
paddle::platform::CPUDeviceContext context(*cpu_place);
paddle::operators::math::gemv<paddle::platform::CPUDeviceContext, T>(
    context, trans, static_cast<int>(m), static_cast<int>(n), 1., data_a,
    data_b, 0., data_c);
GetBlas<T>(context).GEMV(trans, static_cast<int>(m), static_cast<int>(n), 1.,
    data_a, data_b, 0., data_c);
if (!trans) {
for (int i = 0; i < m; ++i) {
......
...@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gtest/gtest.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/device_context.h"
...@@ -23,6 +24,13 @@ void fill_fp16_data(paddle::platform::float16* in_ptr, size_t size,
}
}
template <typename T>
inline paddle::operators::math::BlasT<paddle::platform::CUDADeviceContext, T>
GetBlas(const paddle::platform::CUDADeviceContext& context) {
return paddle::operators::math::GetBlas<paddle::platform::CUDADeviceContext,
T>(context);
}
TEST(math_function, notrans_mul_trans_fp32) {
paddle::framework::Tensor input1;
paddle::framework::Tensor input1_gpu;
...@@ -42,9 +50,8 @@ TEST(math_function, notrans_mul_trans_fp32) {
paddle::framework::TensorCopySync(input1, gpu_place, &input2_gpu);
out_gpu.mutable_data<float>({2, 2}, gpu_place);
paddle::operators::math::matmul<paddle::platform::CUDADeviceContext, float>(
    context, input1_gpu, false, input2_gpu, true, 1, &out_gpu, 0);
GetBlas<float>(context).MatMul(input1_gpu, false, input2_gpu, true, 1,
    &out_gpu, 0);
paddle::framework::TensorCopySync(out_gpu, cpu_place, &out);
...@@ -81,10 +88,9 @@ TEST(math_function, notrans_mul_trans_fp16) {
out_gpu.mutable_data<paddle::platform::float16>({2, 2}, gpu_place);
paddle::operators::math::matmul<paddle::platform::CUDADeviceContext,
    paddle::platform::float16>(
    context, input1_gpu, false, input2_gpu, true,
    paddle::platform::float16(1), &out_gpu, paddle::platform::float16(0));
GetBlas<paddle::platform::float16>(context).MatMul(
    input1_gpu, false, input2_gpu, true, paddle::platform::float16(1),
    &out_gpu, paddle::platform::float16(0));
paddle::framework::TensorCopySync(out_gpu, cpu_place, &out);
...@@ -116,8 +122,8 @@ TEST(math_function, trans_mul_notrans_fp32) {
out_gpu.mutable_data<float>({3, 3}, gpu_place);
paddle::operators::math::matmul<paddle::platform::CUDADeviceContext, float>(
    context, input1_gpu, true, input2_gpu, false, 1, &out_gpu, 0);
GetBlas<float>(context).MatMul(input1_gpu, true, input2_gpu, false, 1,
    &out_gpu, 0);
paddle::framework::TensorCopySync(out_gpu, cpu_place, &out);
...@@ -159,10 +165,9 @@ TEST(math_function, trans_mul_notrans_fp16) {
out_gpu.mutable_data<paddle::platform::float16>({3, 3}, gpu_place);
paddle::operators::math::matmul<paddle::platform::CUDADeviceContext,
    paddle::platform::float16>(
    context, input1_gpu, true, input2_gpu, false,
    paddle::platform::float16(1), &out_gpu, paddle::platform::float16(0));
GetBlas<paddle::platform::float16>(context).MatMul(
    input1_gpu, true, input2_gpu, false, paddle::platform::float16(1),
    &out_gpu, paddle::platform::float16(0));
paddle::framework::TensorCopySync(out_gpu, cpu_place, &out);
...@@ -179,13 +184,6 @@ TEST(math_function, trans_mul_notrans_fp16) {
EXPECT_EQ(static_cast<float>(out_ptr[8]), 29);
}
template <typename T>
inline paddle::operators::math::BlasT<paddle::platform::CUDADeviceContext, T>
GetBlas(const paddle::platform::CUDADeviceContext& context) {
return paddle::operators::math::GetBlas<paddle::platform::CUDADeviceContext,
T>(context);
}
TEST(math_function, gemm_notrans_cublas_fp32) {
paddle::framework::Tensor input1;
paddle::framework::Tensor input2;
...@@ -437,9 +435,8 @@ void GemvTest(int m, int n, bool trans) {
paddle::framework::TensorCopySync(mat_a, gpu_place, &g_mat_a);
paddle::framework::TensorCopySync(vec_b, gpu_place, &g_vec_b);
paddle::operators::math::gemv<paddle::platform::CUDADeviceContext, T>(
    context, trans, static_cast<int>(m), static_cast<int>(n), 1., g_data_a,
    g_data_b, 0., g_data_c);
GetBlas<T>(context).GEMV(trans, static_cast<int>(m), static_cast<int>(n), 1.,
    g_data_a, g_data_b, 0., g_data_c);
paddle::framework::TensorCopySync(g_vec_c, cpu_place, &vec_c);
......
...@@ -15,7 +15,7 @@ limitations under the License. */
#pragma once
#include <algorithm>
#include <vector>
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/blas.h"
namespace paddle {
namespace operators {
...@@ -129,16 +129,17 @@ class MatMulFunctor {
CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
auto blas = GetBlas<DeviceContext, T>(context);
if (!batchCount) {
// regular matrix multiplication
Blas<DeviceContext>(context).GEMM(transA, transB, M, N, kA, alpha,
    a.data<T>(), b.data<T>(), beta,
    out->data<T>());
blas.GEMM(transA, transB, M, N, kA, alpha, a.data<T>(), b.data<T>(), beta,
    out->data<T>());
} else {
// batched matrix multiplication
batched_gemm<DeviceContext, T>(
    context, transA, transB, M, N, kA, alpha, a.data<T>(), b.data<T>(),
    beta, out->data<T>(), batchCount, strideA, strideB);
blas.BatchedGEMM(transA, transB, M, N, kA, alpha, a.data<T>(),
    b.data<T>(), beta, out->data<T>(), batchCount, strideA, strideB);
}
}
};
......
...@@ -14,9 +14,9 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
...@@ -46,9 +46,10 @@ class MulKernel : public framework::OpKernel<T> {
if (z_dim.size() != 2) {
z->Resize({x_matrix.dims()[0], y_matrix.dims()[1]});
}
math::matmul<DeviceContext, T>(
    context.template device_context<DeviceContext>(), x_matrix, false,
    y_matrix, false, static_cast<T>(1), z, static_cast<T>(0));
auto blas = math::GetBlas<DeviceContext, T>(context);
blas.MatMul(x_matrix, y_matrix, z);
if (z_dim.size() != 2) {
z->Resize(z_dim);
}
...@@ -79,6 +80,7 @@ class MulGradKernel : public framework::OpKernel<T> {
Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
Tensor* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
if (dx) {
dx->mutable_data<T>(ctx.GetPlace());
Tensor dx_matrix = dx->dims().size() > 2
...@@ -86,8 +88,7 @@ class MulGradKernel : public framework::OpKernel<T> {
: *dx;
// dx = dout * y'. dx: M x K, dout : M x N, y : K x N
math::matmul<DeviceContext, T>(dev_ctx, dout_mat, false, y_matrix, true,
    1, &dx_matrix, 0);
blas.MatMul(dout_mat, false, y_matrix, true, &dx_matrix);
}
if (dy) {
dy->mutable_data<T>(ctx.GetPlace());
...@@ -95,8 +96,7 @@ class MulGradKernel : public framework::OpKernel<T> {
? framework::ReshapeToMatrix(*dy, y_num_col_dims)
: *dy;
// dy = x' * dout. dy K x N, dout : M x N, x : M x K
math::matmul<DeviceContext, T>(dev_ctx, x_matrix, true, dout_mat, false,
    1, &dy_matrix, 0);
blas.MatMul(x_matrix, true, dout_mat, false, &dy_matrix);
}
}
};
......
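Written out, the two MatMul calls in MulGradKernel above are the standard gradients of a matrix product, matching the shape comments in the code (an explanatory note, not part of the patch):

$$Z = XY,\qquad \frac{\partial L}{\partial X} = \frac{\partial L}{\partial Z}\,Y^{\mathsf{T}},\qquad \frac{\partial L}{\partial Y} = X^{\mathsf{T}}\,\frac{\partial L}{\partial Z},$$

with X of size M x K, Y of size K x N and Z of size M x N, which is exactly "dx = dout * y'" and "dy = x' * dout" as the inline comments state.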
...@@ -58,17 +58,15 @@ class SequenceConvKernel : public framework::OpKernel<T> {
// Because if padding_trainable is false, padding data should be zeros.
math::SetConstant<DeviceContext, T> set_zero;
auto& dev_ctx = context.template device_context<DeviceContext>();
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
set_zero(dev_ctx, &col, static_cast<T>(0));
math::ContextProjectFunctor<DeviceContext, T> seq_project_functor;
seq_project_functor(dev_ctx, *in, *padding_data, padding_trainable,
    context_start, context_length, context_stride, up_pad,
    down_pad, &col);
math::matmul<DeviceContext, T>(dev_ctx, col, false, filter, false,
    static_cast<T>(1.0), out,
    static_cast<T>(0.0));
blas.MatMul(col, filter, out);
}
};
...@@ -99,6 +97,7 @@ class SequenceConvGradKernel : public framework::OpKernel<T> {
math::SetConstant<DeviceContext, T> set_zero;
auto& dev_ctx = context.template device_context<DeviceContext>();
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
// use col_shape in the im2col calculation
framework::DDim col_shape = {in->dims()[0],
    sequence_width * context_length};
...@@ -108,8 +107,7 @@ class SequenceConvGradKernel : public framework::OpKernel<T> {
col.mutable_data<T>(col_shape, context.GetPlace());
// Because if padding_trainable is false, padding data should be zeros.
set_zero(dev_ctx, &col, static_cast<T>(0));
math::matmul<DeviceContext, T>(dev_ctx, *out_g, false, *filter, true,
    T(1.0), &col, T(1.0));
blas.MatMul(*out_g, false, *filter, true, &col);
}
math::ContextProjectFunctor<DeviceContext, T> seq_project_functor;
math::ContextProjectGradFunctor<DeviceContext, T> seq_project_grad_functor;
...@@ -150,8 +148,7 @@ class SequenceConvGradKernel : public framework::OpKernel<T> {
    context_start, context_length, context_stride, up_pad,
    down_pad, &col);
math::matmul<DeviceContext, T>(dev_ctx, col, true, out_grad, false,
    T(1.0), &filter_grad, T(1.0));
blas.MatMul(col, true, out_grad, false, &filter_grad);
}
}
};
......