Commit a3e952f4 authored by peizhilin

add the jit back

fix compile error on windows
Parent 1cc23ef6
@@ -130,6 +130,11 @@ if (APPLE OR WIN32)
         "Disable MKL for building on mac and windows" FORCE)
 endif()
 
+if (WIN32)
+    set(WITH_AVX OFF CACHE STRING
+        "Disable AVX when compiling for Windows" FORCE)
+endif()
+
 set(THIRD_PARTY_PATH "${CMAKE_BINARY_DIR}/third_party" CACHE STRING
     "A path setting third party libraries download & build directories.")
......
@@ -84,9 +84,8 @@ function(op_library TARGET)
     endif()
     if (WIN32)
         # remove windows unsupported op, because windows has no nccl, no warpctc such ops.
-        foreach(windows_unsupport_op "nccl_op" "gen_nccl_id_op" "warpctc_op" "hierarchical_sigmoid_op"
-            "crf_decoding_op" "select_op" "lstmp_op" "gru_op" "fusion_gru_op" "lstm_op" "fusion_lstm_op" "cumsum_op"
-            "fusion_seqconv_eltadd_relu_op" "channel_send_op" "channel_create_op" "channel_close_op" "channel_recv_op")
+        foreach(windows_unsupport_op "nccl_op" "gen_nccl_id_op" "warpctc_op"
+            "channel_send_op" "channel_create_op" "channel_close_op" "channel_recv_op")
         if ("${TARGET}" STREQUAL "${windows_unsupport_op}")
             return()
         endif()
......
@@ -70,17 +70,20 @@ int main()
     return 0;
 }" AVX_FOUND)
 
-# Check AVX 2
-set(CMAKE_REQUIRED_FLAGS ${AVX2_FLAG})
-set(AVX2_FOUND_EXITCODE 1 CACHE STRING "Result from TRY_RUN" FORCE)
-CHECK_CXX_SOURCE_RUNS("
-#include <immintrin.h>
-int main()
-{
-    __m256i a = _mm256_set_epi32 (-1, 2, -3, 4, -1, 2, -3, 4);
-    __m256i result = _mm256_abs_epi32 (a);
-    return 0;
-}" AVX2_FOUND)
+# disable AVX2 by default on windows
+if(NOT WIN32)
+    # Check AVX 2
+    set(CMAKE_REQUIRED_FLAGS ${AVX2_FLAG})
+    set(AVX2_FOUND_EXITCODE 1 CACHE STRING "Result from TRY_RUN" FORCE)
+    CHECK_CXX_SOURCE_RUNS("
+    #include <immintrin.h>
+    int main()
+    {
+        __m256i a = _mm256_set_epi32 (-1, 2, -3, 4, -1, 2, -3, 4);
+        __m256i result = _mm256_abs_epi32 (a);
+        return 0;
+    }" AVX2_FOUND)
+endif(NOT WIN32)
 
 # Check AVX512F
 set(CMAKE_REQUIRED_FLAGS ${AVX512F_FLAG})
......
@@ -22,9 +22,7 @@ if(WITH_DISTRIBUTE)
     add_subdirectory(distributed_ops)
 endif()
 
-if (NOT WIN32)
-    add_subdirectory(reader)
-endif()
+add_subdirectory(reader)
 
 if (NOT WIN32)
     add_subdirectory(nccl)
@@ -49,9 +47,10 @@ endif()
 
 set(COMMON_OP_DEPS "")
-set(COMMON_OP_DEPS ${COMMON_OP_DEPS} xxhash selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor dynload_warpctc sequence_padding sequence_scale cos_sim_functor memory jit_kernel concat_and_split cross_entropy softmax vol2col im2col sampler)
+set(COMMON_OP_DEPS ${COMMON_OP_DEPS} xxhash selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor sequence_padding sequence_scale cos_sim_functor memory concat_and_split cross_entropy softmax vol2col im2col sampler)
+set(COMMON_OP_DEPS ${COMMON_OP_DEPS} lstm_compute matrix_bit_code sequence2batch gru_compute activation_functions jit_kernel)
 if (NOT WIN32)
-    set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions)
+    set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc)
 endif()
 if (WITH_GPU)
     set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv cub)
......
@@ -111,7 +111,7 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
     auto pre_out_mat = EigenMatrix<T>::From(*pre_out);
     auto pre_out_grad_mat = EigenMatrix<T>::From(pre_out_grad);
     auto out_grad_mat = EigenMatrix<T>::From(*out_grad);
-    Eigen::array<int, 2> bcast({{1, static_cast<int>(pre_out_grad.dims()[1])}});
+    Eigen::array<int, 2> bcast{1, static_cast<int>(pre_out_grad.dims()[1])};
 
     // softrelu derivative
     pre_out_grad_mat.device(place) =
......
-if (NOT WIN32)
-    add_subdirectory(detail)
-endif(NOT WIN32)
+add_subdirectory(detail)
 
 function(math_library TARGET)
     # math_library is a function to create math library.
@@ -43,10 +41,8 @@ math_library(depthwise_conv)
 math_library(im2col)
 math_library(sampler)
 
-if (NOT WIN32) # windows do not support avx functions yet.
-    math_library(gru_compute DEPS activation_functions math_function)
-    math_library(lstm_compute DEPS activation_functions)
-endif (NOT WIN32)
+math_library(gru_compute DEPS activation_functions math_function)
+math_library(lstm_compute DEPS activation_functions)
 
 cc_library(blas SRCS blas.cc DEPS cblas framework_proto device_context)
 math_library(math_function DEPS blas)
@@ -58,9 +54,9 @@ math_library(sequence_padding)
 math_library(sequence_pooling DEPS math_function)
 math_library(sequence_scale)
 math_library(softmax DEPS math_function)
-if (NOT WIN32)
 math_library(matrix_bit_code)
-endif (NOT WIN32)
 math_library(unpooling)
 math_library(vol2col)
@@ -76,13 +72,12 @@ if(WITH_GPU)
 endif()
 cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split)
 cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info)
-if (NOT WIN32)
 set(JIT_KERNEL_SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc jit_kernel_crf_decode.cc)
 set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce)
 if(WITH_XBYAK)
     list(APPEND JIT_KERNEL_SRCS jit_gen.cc jit_code.cc)
     list(APPEND JIT_KERNEL_DEPS xbyak)
 endif()
 cc_library(jit_kernel SRCS ${JIT_KERNEL_SRCS} DEPS ${JIT_KERNEL_DEPS})
 cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel)
-endif (NOT WIN32)
@@ -67,7 +67,7 @@ inline constexpr size_t FindLastSet(size_t x) {
                    : (std::is_same<size_t, unsigned long>::value  // NOLINT
                           ? (x ? 8 * sizeof(x) - __builtin_clzl(x) : 0)
                           : (x ? 8 * sizeof(x) - __builtin_clzll(x) : 0));
+}
 #else
 // windows don't have built-in clz, ctz function
 template <typename T>
@@ -92,7 +92,6 @@ inline int clz(const T& value) {
 inline size_t FindLastSet(size_t x) { return sizeof(size_t) * 8 - clz(x); }
 #endif  // !_WIN32
-}
 
 struct SimpleCode {
   SimpleCode(size_t code, size_t num_classes) : c_(code + num_classes) {}
......
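Editor's note (not part of the commit): `FindLastSet(x)` returns the 1-based position of the highest set bit of `x`, and 0 for `x == 0`; the GCC/Clang branch computes it with `__builtin_clz*`, while the Windows branch emulates `clz` by hand. A minimal Python sketch of the same semantics, purely for illustration:

```python
def find_last_set(x: int) -> int:
    """1-based index of the highest set bit of x, or 0 if x == 0.

    Mirrors FindLastSet in matrix_bit_code.h, i.e.
    8 * sizeof(size_t) - clz(x) for x > 0.
    """
    return x.bit_length()


assert find_last_set(0) == 0
assert find_last_set(1) == 1
assert find_last_set(0b10110) == 5  # highest set bit is bit 4 (0-based)
```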
@@ -170,12 +170,6 @@ __all__ = [
     'bilinear_tensor_product',
 ]
 
-# To avoid the api checker complains
-if os.name == 'nt':
-    __all__.remove('dynamic_lstm')
-    __all__.remove('crf_decoding')
-    __all__.remove('roi_pool')
-
 
 def fc(input,
        size,
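With the Windows-only `__all__.remove(...)` calls gone, `dynamic_lstm`, `crf_decoding` and `roi_pool` stay exported on every platform. A quick, hypothetical sanity check on a Windows build of this branch (assumes the usual `paddle.fluid` import works there):

```python
import os
import paddle.fluid as fluid

# These names used to be stripped from __all__ when os.name == 'nt';
# after this commit they should remain exported on Windows as well.
for name in ('dynamic_lstm', 'crf_decoding', 'roi_pool'):
    assert name in fluid.layers.__all__, name
print(os.name, 'exports OK')
```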
@@ -349,128 +343,126 @@ def embedding(input,
     return tmp
 
 
-if os.name != 'nt':
 
 @templatedoc(op_type="lstm")
 def dynamic_lstm(input,
                  size,
                  h_0=None,
                  c_0=None,
                  param_attr=None,
                  bias_attr=None,
                  use_peepholes=True,
                  is_reverse=False,
                  gate_activation='sigmoid',
                  cell_activation='tanh',
                  candidate_activation='tanh',
                  dtype='float32',
                  name=None):
     """
     ${comment}
 
     Args:
         input (Variable): ${input_comment}
         size (int): 4 * hidden size.
         h_0(Variable): The initial hidden state is an optional input, default is zero.
                        This is a tensor with shape (N x D), where N is the
                        batch size and D is the hidden size.
         c_0(Variable): The initial cell state is an optional input, default is zero.
                        This is a tensor with shape (N x D), where N is the
                        batch size. `h_0` and `c_0` can be NULL but only at the same time.
         param_attr(ParamAttr|None): The parameter attribute for the learnable
                                hidden-hidden weights.
 
                                - Weights = {:math:`W_{ch}, W_{ih}, \
                                                W_{fh}, W_{oh}`}
                                - The shape is (D x 4D), where D is the hidden
                                  size.
 
                                If it is set to None or one attribute of ParamAttr,
                                dynamic_lstm will create ParamAttr as param_attr.
                                If the Initializer of the param_attr is not set, the
                                parameter is initialized with Xavier. Default: None.
         bias_attr (ParamAttr|None): The bias attribute for the learnable bias
                               weights, which contains two parts, input-hidden
                               bias weights and peephole connections weights if
                               setting `use_peepholes` to `True`.
 
                               1. `use_peepholes = False`
                                  - Biases = {:math:`b_c, b_i, b_f, b_o`}.
                                  - The shape is (1 x 4D).
                               2. `use_peepholes = True`
                                  - Biases = { :math:`b_c, b_i, b_f, b_o, W_{ic}, \
                                                  W_{fc}, W_{oc}`}.
                                  - The shape is (1 x 7D).
 
                               If it is set to None or one attribute of ParamAttr,
                               dynamic_lstm will create ParamAttr as bias_attr.
                               If the Initializer of the bias_attr is not set,
                               the bias is initialized zero. Default: None.
         use_peepholes (bool): ${use_peepholes_comment}
         is_reverse (bool): ${is_reverse_comment}
         gate_activation (str): ${gate_activation_comment}
         cell_activation (str): ${cell_activation_comment}
         candidate_activation (str): ${candidate_activation_comment}
         dtype (str): Data type. Choices = ["float32", "float64"], default "float32".
         name (str|None): A name for this layer(optional). If set None, the layer
                          will be named automatically.
 
     Returns:
         tuple: The hidden state, and cell state of LSTM. The shape of both \
         is (T x D), and lod is the same with the `input`.
 
     Examples:
         .. code-block:: python
 
             hidden_dim = 512
             forward_proj = fluid.layers.fc(input=input_seq, size=hidden_dim * 4,
                                            bias_attr=False)
             forward, _ = fluid.layers.dynamic_lstm(
                 input=forward_proj, size=hidden_dim * 4, use_peepholes=False)
     """
     assert bias_attr is not False, "bias_attr should not be False in dynamic_lstmp."
     helper = LayerHelper('lstm', **locals())
     size = size // 4
     weight = helper.create_parameter(
         attr=helper.param_attr, shape=[size, 4 * size], dtype=dtype)
     bias_size = [1, 7 * size]
     if not use_peepholes:
         bias_size[1] = 4 * size
     bias = helper.create_parameter(
         attr=helper.bias_attr, shape=bias_size, dtype=dtype, is_bias=True)
 
     hidden = helper.create_variable_for_type_inference(dtype)
     cell = helper.create_variable_for_type_inference(dtype)
     batch_gate = helper.create_variable_for_type_inference(dtype)
     batch_cell_pre_act = helper.create_variable_for_type_inference(dtype)
     inputs = {'Input': input, 'Weight': weight, 'Bias': bias}
     batch_size = input.shape[0]
     if h_0:
         assert h_0.shape == (batch_size, size), \
             'The shape of h0 should be (batch_size, %d)' % size
         inputs['H0'] = h_0
     if c_0:
         assert c_0.shape == (batch_size, size), \
             'The shape of c0 should be (batch_size, %d)' % size
         inputs['C0'] = c_0
 
     helper.append_op(
         type='lstm',
         inputs=inputs,
         outputs={
             'Hidden': hidden,
             'Cell': cell,
             'BatchGate': batch_gate,
             'BatchCellPreAct': batch_cell_pre_act
         },
         attrs={
             'use_peepholes': use_peepholes,
             'is_reverse': is_reverse,
             'gate_activation': gate_activation,
             'cell_activation': cell_activation,
             'candidate_activation': candidate_activation
         })
     return hidden, cell
 
 
 def dynamic_lstmp(input,
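The code block in the docstring above is only a fragment; a more self-contained usage sketch of `dynamic_lstm` might look like the following (layer names and dimensions are illustrative, not taken from the commit):

```python
import paddle.fluid as fluid

hidden_dim = 512
# dynamic_lstm works on LoD (variable-length) sequences, hence lod_level=1.
input_seq = fluid.layers.data(
    name='word_emb', shape=[128], dtype='float32', lod_level=1)
forward_proj = fluid.layers.fc(
    input=input_seq, size=hidden_dim * 4, bias_attr=False)
forward, cell = fluid.layers.dynamic_lstm(
    input=forward_proj, size=hidden_dim * 4, use_peepholes=False)
# forward and cell both have shape (T x hidden_dim) with the input's LoD.
```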
@@ -969,43 +961,39 @@ def linear_chain_crf(input, label, param_attr=None):
     return log_likelihood
 
 
-if os.name != 'nt':
 
 @templatedoc()
 def crf_decoding(input, param_attr, label=None):
     """
     ${comment}
 
     Args:
         input(${emission_type}): ${emission_comment}
 
         param_attr(ParamAttr): The parameter attribute for training.
 
         label(${label_type}): ${label_comment}
 
     Returns:
         Variable: ${viterbi_path_comment}
 
     Examples:
         .. code-block:: python
 
            crf_decode = layers.crf_decoding(
                 input=hidden, param_attr=ParamAttr(name="crfw"))
     """
     helper = LayerHelper('crf_decoding', **locals())
     transition = helper.get_parameter(param_attr.name)
     viterbi_path = helper.create_variable_for_type_inference(
         dtype=helper.input_dtype())
     helper.append_op(
         type='crf_decoding',
         inputs={"Emission": [input],
                 "Transition": transition,
                 "Label": label},
         outputs={"ViterbiPath": [viterbi_path]})
 
     return viterbi_path
 
 
 @templatedoc()
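For context, `crf_decoding` is normally paired with `linear_chain_crf` (the hunk above starts inside that function): training optimizes the CRF log-likelihood, and inference reuses the same transition parameter for Viterbi decoding. A hedged usage sketch, with the `emission`/`label` shapes purely illustrative:

```python
import paddle.fluid as fluid
from paddle.fluid import ParamAttr

emission = fluid.layers.data(
    name='emission', shape=[10], dtype='float32', lod_level=1)
label = fluid.layers.data(
    name='label', shape=[1], dtype='int64', lod_level=1)

# Training branch: learns the transition matrix under the name "crfw".
crf_cost = fluid.layers.linear_chain_crf(
    input=emission, label=label, param_attr=ParamAttr(name='crfw'))

# Inference branch: decode with the same transition parameter.
crf_decode = fluid.layers.crf_decoding(
    input=emission, param_attr=ParamAttr(name='crfw'))
```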
@@ -5599,48 +5587,42 @@ def label_smooth(label,
     return smooth_label
 
 
-if os.name != 'nt':
 
 @templatedoc()
 def roi_pool(input, rois, pooled_height=1, pooled_width=1, spatial_scale=1.0):
     """
     ${comment}
 
     Args:
         input (Variable): ${x_comment}
         rois (Variable): ROIs (Regions of Interest) to pool over.
         pooled_height (integer): ${pooled_height_comment} Default: 1
         pooled_width (integer): ${pooled_width_comment} Default: 1
         spatial_scale (float): ${spatial_scale_comment} Default: 1.0
 
     Returns:
         Variable: ${out_comment}.
 
     Examples:
         .. code-block:: python
 
             pool_out = fluid.layers.roi_pool(input=x, rois=rois, 7, 7, 1.0)
     """
     helper = LayerHelper('roi_pool', **locals())
     dtype = helper.input_dtype()
     pool_out = helper.create_variable_for_type_inference(dtype)
     argmaxes = helper.create_variable_for_type_inference(dtype='int32')
     helper.append_op(
         type="roi_pool",
         inputs={"X": input,
                 "ROIs": rois},
         outputs={"Out": pool_out,
                  "Argmax": argmaxes},
         attrs={
             "pooled_height": pooled_height,
             "pooled_width": pooled_width,
             "spatial_scale": spatial_scale
         })
     return pool_out
 
 
 @templatedoc()
......
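A side note on the `roi_pool` docstring example above: it passes positional arguments after keyword arguments, which does not parse as written. A keyword-only form of the same call, with the input shapes chosen purely for illustration, would be:

```python
import paddle.fluid as fluid

x = fluid.layers.data(
    name='feature_map', shape=[256, 14, 14], dtype='float32')
rois = fluid.layers.data(
    name='rois', shape=[4], dtype='float32', lod_level=1)

# Keyword form of the docstring example; mixing positional arguments after
# keywords, as the docstring does, is a syntax error.
pool_out = fluid.layers.roi_pool(
    input=x, rois=rois, pooled_height=7, pooled_width=7, spatial_scale=1.0)
```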
@@ -100,26 +100,27 @@ Examples:
         >>> result = fluid.layers.hard_shrink(x=data, threshold=0.3)
 """
 
-if os.name != 'nt':
 
 __all__ += ['cumsum']
 
 _cum_sum_ = generate_layer_fn('cumsum')
 
 
 def cumsum(x, axis=None, exclusive=None, reverse=None):
     locals_var = locals().keys()
     kwargs = dict()
     for name in locals_var:
         val = locals()[name]
         if val is not None:
             kwargs[name] = val
     return _cum_sum_(**kwargs)
 
 
 cumsum.__doc__ = _cum_sum_.__doc__ + """
 Examples:
 
     >>> data = fluid.layers.data(name="input", shape=[32, 784])
     >>> result = fluid.layers.cumsum(data, axis=0)
 """
 
 __all__ += ['thresholded_relu']
......
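Since `cumsum` is generated from the operator via `generate_layer_fn`, its Python surface is just keyword forwarding: `axis`, `exclusive` and `reverse` are passed to the op only when they are not None. A small self-contained sketch based on the docstring example (shapes illustrative):

```python
import paddle.fluid as fluid

data = fluid.layers.data(name='input', shape=[32, 784], dtype='float32')
# Running sum along axis 0; omitted keyword arguments are not forwarded.
result = fluid.layers.cumsum(data, axis=0)
```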