Commit a8e18549 authored by dangqingqing

Fix the clang format.

Parent d760b6a5
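Note while reviewing: the hunks below are almost entirely mechanical (alphabetical include sorting and parameter repacking to the column limit); the one behavioral change is that GetActivationType now also accepts an empty string. For orientation, a minimal .clang-format sketch that would produce edits like these (an illustration only, not the repository's actual config file):

# Hypothetical .clang-format sketch; Paddle's real config may differ.
BasedOnStyle: Google      # 2-space indent, 80-column limit
SortIncludes: true        # alphabetizes the #include blocks reordered below
BinPackParameters: true   # repacks parameter lists up to the column limit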
@@ -14,10 +14,10 @@ limitations under the License. */
 #pragma once
 #include "paddle/framework/op_registry.h"
+#include "paddle/operators/math/detail/activation_functions.h"
 #include "paddle/operators/math/lstm_compute.h"
 #include "paddle/operators/math/math_function.h"
 #include "paddle/operators/math/sequence2batch.h"
-#include "paddle/operators/math/detail/activation_functions.h"
 namespace paddle {
 namespace operators {
......
@@ -14,8 +14,8 @@ limitations under the License. */
 #pragma once
 #include <math.h>
-#include "paddle/platform/hostdevice.h"
 #include "paddle/platform/enforce.h"
+#include "paddle/platform/hostdevice.h"
 #ifdef __AVX__
 #include <immintrin.h>
@@ -37,20 +37,19 @@ enum ActivationType {
   kIdentity,
 };
-inline ActivationType GetActivationType (const std::string &type) {
+inline ActivationType GetActivationType(const std::string &type) {
   if (type == "sigmoid") {
     return ActivationType::kSigmoid;
   } else if (type == "relu") {
     return ActivationType::kReLU;
   } else if (type == "tanh") {
     return ActivationType::kTanh;
-  } else if (type == "identity") {
+  } else if (type == "identity" || type == "") {
     return ActivationType::kIdentity;
   }
   PADDLE_THROW("Not support type %s.", type);
 }
 namespace forward {
 template <typename T>
......
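Aside from the whitespace fix, the hunk above makes GetActivationType treat an unset ("") activation as identity. A standalone sketch of the same dispatch for reference (hypothetical names; std::invalid_argument stands in for PADDLE_THROW):

#include <stdexcept>
#include <string>

enum class ActivationType { kSigmoid, kReLU, kTanh, kIdentity };

// Mirrors the dispatch above: an empty activation string means identity.
inline ActivationType GetActivationTypeSketch(const std::string &type) {
  if (type == "sigmoid") return ActivationType::kSigmoid;
  if (type == "relu") return ActivationType::kReLU;
  if (type == "tanh") return ActivationType::kTanh;
  if (type == "identity" || type == "") return ActivationType::kIdentity;
  throw std::invalid_argument("Not support type " + type);
}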
@@ -26,8 +26,7 @@ namespace detail {
 template <class T, class Op>
 void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
-                                     int frame_size,
-                                     ActivationType active_node,
+                                     int frame_size, ActivationType active_node,
                                      ActivationType active_gate,
                                      ActivationType active_state) {
   T r_value_in;
@@ -149,8 +148,7 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
 template <class T, class Op>
 void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
-                                   int frame_size,
-                                   ActivationType active_node,
+                                   int frame_size, ActivationType active_node,
                                    ActivationType active_gate,
                                    ActivationType active_state) {
 #ifdef __AVX__
@@ -281,8 +279,7 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
 template <class T, class Op>
 void cpu_lstm_forward(Op op, LstmMetaValue<T> value, int frame_size,
-                      ActivationType active_node,
-                      ActivationType active_gate,
+                      ActivationType active_node, ActivationType active_gate,
                       ActivationType active_state) {
   if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same<T, float>::value)) {
     avx_lstm_forward_one_sequence<T>(op, value, frame_size, active_node,
......
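The guard visible in cpu_lstm_forward above takes the AVX path only for float data whose frame_size is a multiple of 8 (one 256-bit register holds eight floats): frame_size & (8 - 1) isolates the low three bits, so it is zero exactly when frame_size % 8 == 0. A tiny self-contained sketch of the same predicate (hypothetical helper name):

#include <cstdio>
#include <type_traits>

// Same shape as the guard above: AVX only for float frames
// whose size is a multiple of 8.
template <typename T>
bool UseAvxPath(bool op_has_avx, int frame_size) {
  return op_has_avx && !(frame_size & (8 - 1)) && std::is_same<T, float>::value;
}

int main() {
  std::printf("%d %d %d\n",
              UseAvxPath<float>(true, 64),    // 1: 64 is a multiple of 8
              UseAvxPath<float>(true, 60),    // 0: 60 % 8 != 0
              UseAvxPath<double>(true, 64));  // 0: not float
}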
@@ -185,8 +185,7 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
 template <class T, class Op>
 void gpu_lstm_forward(const platform::DeviceContext& context, Op op,
                       LstmMetaValue<T> value, int frame_size, int batch_size,
-                      ActivationType active_node,
-                      ActivationType active_gate,
+                      ActivationType active_node, ActivationType active_gate,
                       ActivationType active_state) {
   dim3 threads;
   dim3 grid;
@@ -220,8 +219,7 @@ template <class T, class Op>
 void gpu_lstm_backward(const platform::DeviceContext& context, Op op,
                        LstmMetaValue<T> value, LstmMetaGrad<T> grad,
                        int frame_size, int batch_size,
-                       ActivationType active_node,
-                       ActivationType active_gate,
+                       ActivationType active_node, ActivationType active_gate,
                        ActivationType active_state) {
   dim3 threads;
   dim3 grid;
......
@@ -28,8 +28,8 @@ struct LstmUnitFunctor<platform::CUDADeviceContext, T> {
                          const detail::ActivationType& cell_act,
                          const detail::ActivationType& cand_act) {
     detail::gpu_lstm_forward<T>(context, detail::forward::lstm<T>(), value,
-                                frame_size, batch_size, cand_act,
-                                gate_act, cell_act);
+                                frame_size, batch_size, cand_act, gate_act,
+                                cell_act);
   }
 };
@@ -42,8 +42,8 @@ struct LstmUnitGradFunctor<platform::CUDADeviceContext, T> {
                          const detail::ActivationType& cell_act,
                          const detail::ActivationType& cand_act) {
     detail::gpu_lstm_backward(context, detail::backward::lstm<T>(), value, grad,
-                              frame_size, batch_size, cand_act,
-                              gate_act, cell_act);
+                              frame_size, batch_size, cand_act, gate_act,
+                              cell_act);
   }
 };
......
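Worth noting while reading the two functor hunks above: the reflow changes only where the lines break, not the argument order. The caller's cand_act is bound to the kernel's active_node parameter, gate_act to active_gate, and cell_act to active_state. A minimal stand-in illustrating that positional binding (hypothetical function, not the real kernel):

#include <iostream>

// Parameter names as in the kernels above; ints stand in for the enums.
void gpu_lstm_forward_sketch(int active_node, int active_gate,
                             int active_state) {
  std::cout << active_node << ' ' << active_gate << ' ' << active_state << '\n';
}

int main() {
  int gate_act = 1, cell_act = 2, cand_act = 3;
  // Same order as the call sites above: cand_act, gate_act, cell_act.
  gpu_lstm_forward_sketch(cand_act, gate_act, cell_act);  // prints "3 1 2"
}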
@@ -14,9 +14,9 @@ limitations under the License. */
 #pragma once
+#include "paddle/operators/math/detail/activation_functions.h"
 #include "paddle/platform/device_context.h"
 #include "paddle/platform/enforce.h"
-#include "paddle/operators/math/detail/activation_functions.h"
 namespace paddle {
 namespace operators {
@@ -30,7 +30,6 @@ typedef enum {
   HL_ACTIVATION_END
 } activation_mode_t;
 template <class T>
 struct LstmMetaValue {
   T *gate_value;
......