未验证 提交 0c8c0d94 编写于 作者: G gongweibao 提交者: GitHub

fix macunittest (#13434)

上级 1e44201c
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -13,76 +10,9 @@ See the License for the specific language governing permissions and ...@@ -13,76 +10,9 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/math/cpu_lstm_compute.h" #include "paddle/fluid/operators/math/cpu_lstm_compute.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/platform/cpu_info.h"
#ifdef __AVX__
#include <immintrin.h>
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {} // namespace math
// TODO(TJ): ugly workaround, clean me
template <typename T>
void lstm_compute_ctht(T* gates, const T* ct_1, T* ct, T* ht) {
// gates: W_ch, W_ih, W_fh, W_oh
vec_sigmoid<T, platform::jit::avx>(24, gates + 8, gates + 8);
vec_tanh<T, platform::jit::avx>(8, gates, gates);
const T *i = gates + 8, *f = gates + 16, *o = gates + 24;
const T min = SIGMOID_THRESHOLD_MIN;
const T max = SIGMOID_THRESHOLD_MAX;
for (int d = 0; d < 8; ++d) {
// C_t = C_t-1 * fgated + cand_gated * igated
ct[d] = ct_1[d] * f[d] + gates[d] * i[d];
// H_t = act_cell(C_t) * ogated
T tmp = ct[d] * 2;
tmp = static_cast<T>(0) - ((tmp < min) ? min : ((tmp > max) ? max : tmp));
vec_exp<T>(1, &tmp, &tmp);
tmp = static_cast<T>(2) / (static_cast<T>(1) + tmp) - static_cast<T>(1);
ht[d] = tmp * o[d];
}
}
#ifdef __AVX__
namespace detail {
namespace forward {
namespace avx {
__m256 Sigmoid(const __m256 a);
__m256 Tanh(const __m256 a);
} // namespace avx
} // namespace forward
} // namespace detail
template <>
void lstm_compute_ctht<float>(float* gates, const float* ct_1, float* ct,
float* ht) {
namespace act = detail::forward::avx;
// gates: W_ch, W_ih, W_fh, W_oh
__m256 c, i, f, o;
c = _mm256_loadu_ps(gates);
i = _mm256_loadu_ps(gates + 8);
f = _mm256_loadu_ps(gates + 16);
o = _mm256_loadu_ps(gates + 24);
/* C_t = C_t-1 * fgated + cand_gated * igated*/
c = _mm256_mul_ps(act::Tanh(c), act::Sigmoid(i));
i = _mm256_loadu_ps(ct_1);
f = _mm256_mul_ps(i, act::Sigmoid(f));
f = _mm256_add_ps(c, f);
_mm256_storeu_ps(ct, f);
/* H_t = act_cell(C_t) * ogated */
o = _mm256_mul_ps(act::Tanh(f), act::Sigmoid(o));
_mm256_storeu_ps(ht, o);
}
#endif
template void lstm_compute_ctht<float>(float* gates, const float* ct_1,
float* ct, float* ht);
template void lstm_compute_ctht<double>(double* gates, const double* ct_1,
double* ct, double* ht);
} // namespace math
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -14,6 +11,11 @@ limitations under the License. */ ...@@ -14,6 +11,11 @@ limitations under the License. */
#pragma once #pragma once
#include <string> #include <string>
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/platform/cpu_info.h"
#ifdef __AVX__
#include <immintrin.h>
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -21,7 +23,58 @@ namespace math { ...@@ -21,7 +23,58 @@ namespace math {
// TODO(TJ): ugly workaround, clean me // TODO(TJ): ugly workaround, clean me
template <typename T> template <typename T>
void lstm_compute_ctht(T* gates, const T* ct_1, T* ct, T* ht); void lstm_compute_ctht(T* gates, const T* ct_1, T* ct, T* ht) {
// gates: W_ch, W_ih, W_fh, W_oh
vec_sigmoid<T, platform::jit::avx>(24, gates + 8, gates + 8);
vec_tanh<T, platform::jit::avx>(8, gates, gates);
const T *i = gates + 8, *f = gates + 16, *o = gates + 24;
const T min = SIGMOID_THRESHOLD_MIN;
const T max = SIGMOID_THRESHOLD_MAX;
for (int d = 0; d < 8; ++d) {
// C_t = C_t-1 * fgated + cand_gated * igated
ct[d] = ct_1[d] * f[d] + gates[d] * i[d];
// H_t = act_cell(C_t) * ogated
T tmp = ct[d] * 2;
tmp = static_cast<T>(0) - ((tmp < min) ? min : ((tmp > max) ? max : tmp));
vec_exp<T>(1, &tmp, &tmp);
tmp = static_cast<T>(2) / (static_cast<T>(1) + tmp) - static_cast<T>(1);
ht[d] = tmp * o[d];
}
}
#ifdef __AVX__
namespace detail {
namespace forward {
namespace avx {
__m256 Sigmoid(const __m256 a);
__m256 Tanh(const __m256 a);
} // namespace avx
} // namespace forward
} // namespace detail
template <>
void lstm_compute_ctht<float>(float* gates, const float* ct_1, float* ct,
float* ht) {
namespace act = detail::forward::avx;
// gates: W_ch, W_ih, W_fh, W_oh
__m256 c, i, f, o;
c = _mm256_loadu_ps(gates);
i = _mm256_loadu_ps(gates + 8);
f = _mm256_loadu_ps(gates + 16);
o = _mm256_loadu_ps(gates + 24);
/* C_t = C_t-1 * fgated + cand_gated * igated*/
c = _mm256_mul_ps(act::Tanh(c), act::Sigmoid(i));
i = _mm256_loadu_ps(ct_1);
f = _mm256_mul_ps(i, act::Sigmoid(f));
f = _mm256_add_ps(c, f);
_mm256_storeu_ps(ct, f);
/* H_t = act_cell(C_t) * ogated */
o = _mm256_mul_ps(act::Tanh(f), act::Sigmoid(o));
_mm256_storeu_ps(ht, o);
}
#endif
} // namespace math } // namespace math
} // namespace operators } // namespace operators
......
...@@ -109,15 +109,20 @@ def get_transpiler(trainer_id, main_program, pserver_endpoints, trainers): ...@@ -109,15 +109,20 @@ def get_transpiler(trainer_id, main_program, pserver_endpoints, trainers):
return t return t
from paddle.fluid.transpiler.details import op_to_code
def operator_equal(a, b): def operator_equal(a, b):
if op_to_code(a) != op_to_code(b):
raise ValueError("In operator_equal not equal\n")
for k, v in six.iteritems(a.__dict__): for k, v in six.iteritems(a.__dict__):
if isinstance(v, fluid.framework.Program) or \ if isinstance(v, fluid.framework.Program) or \
isinstance(v, fluid.framework.Block): isinstance(v, fluid.framework.Block):
continue continue
elif isinstance(v, core.OpDesc): elif isinstance(v, core.OpDesc):
if v.serialize_to_string() != b.__dict__[k].serialize_to_string(): continue
raise ValueError("In operator_equal not equal:{0}\n".format(k))
elif isinstance(v, collections.OrderedDict): elif isinstance(v, collections.OrderedDict):
v0 = sorted(list(six.iteritems(v)), key=lambda x: x[0]) v0 = sorted(list(six.iteritems(v)), key=lambda x: x[0])
......
...@@ -113,27 +113,32 @@ def op_to_code(op): ...@@ -113,27 +113,32 @@ def op_to_code(op):
inputs_str += ", " inputs_str += ", "
inputs_str += "}" inputs_str += "}"
attr_names = sorted(op.attr_names)
attrs_str = "" attrs_str = ""
for i in range(0, len(op.attr_names)): for i in range(0, len(attr_names)):
name = op.attr_names[i] name = attr_names[i]
attr_type = op.desc.attr_type(name) attr_type = op.desc.attr_type(name)
if attr_type == core.AttrType.BLOCK: if attr_type == core.AttrType.BLOCK:
a = "{name} = block[{value}]".format( a = "{name} = block[{value}]".format(
name=name, type=attr_type, value=op.block_attr_id(name)) name=name, type=attr_type, value=op.block_attr_id(name))
attrs_str += a attrs_str += a
if i != len(attr_names) - 1:
attrs_str += ", "
continue continue
if attr_type == core.AttrType.BLOCKS: if attr_type == core.AttrType.BLOCKS:
a = "{name} = blocks{value}".format( a = "{name} = blocks{value}".format(
name=name, type=attr_type, value=op.blocks_attr_ids(name)) name=name, type=attr_type, value=op.blocks_attr_ids(name))
attrs_str += a attrs_str += a
if i != len(attr_names) - 1:
attrs_str += ", "
continue continue
a = "{name} = {value}".format( a = "{name} = {value}".format(
name=name, type=attr_type, value=op.desc.attr(name)) name=name, type=attr_type, value=op.desc.attr(name))
attrs_str += a attrs_str += a
if i != len(op.attr_names) - 1: if i != len(attr_names) - 1:
attrs_str += ", " attrs_str += ", "
if outputs_str != "{}": if outputs_str != "{}":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册