Commit 49c28b8c authored by nhzlx

Merge branch 'develop' of https://github.com/paddlepaddle/paddle into add_params_sync_pass

test=develop
@@ -27,13 +27,14 @@ SET(GZSTREAM_INCLUDE_DIR "${GZSTREAM_INSTALL_DIR}/include/" CACHE PATH "gzstream
 ExternalProject_Add(
     extern_gzstream
+    DEPENDS zlib
     GIT_REPOSITORY "https://github.com/jacquesqiao/gzstream.git"
     GIT_TAG ""
     PREFIX ${GZSTREAM_SOURCES_DIR}
     UPDATE_COMMAND ""
     CONFIGURE_COMMAND ""
     BUILD_IN_SOURCE 1
-    BUILD_COMMAND make -j8
+    BUILD_COMMAND make EXTERN_CPPFLAGS="-I${THIRD_PARTY_PATH}/install/zlib/include" EXTERM_LDFLAGS="-L${THIRD_PARTY_PATH}/install/zlib/lib" -j8
     INSTALL_COMMAND mkdir -p ${GZSTREAM_INSTALL_DIR}/lib/ && mkdir -p ${GZSTREAM_INSTALL_DIR}/include/
     && cp ${GZSTREAM_SOURCES_DIR}/src/extern_gzstream/libgzstream.a ${GZSTREAM_INSTALL_DIR}/lib
     && cp -r ${GZSTREAM_SOURCES_DIR}/src/extern_gzstream/gzstream.h ${GZSTREAM_INSTALL_DIR}/include
......
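The two additions tie gzstream to the zlib that Paddle already builds as a third-party dependency: DEPENDS zlib orders the external projects, and the extra make variables point gzstream's compile and link steps at the zlib staged under ${THIRD_PARTY_PATH}/install/zlib.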
@@ -81,13 +81,35 @@ class CompileTimeInferShapeContext : public InferShapeContext {
                    "The %s[%d] is @EMPTY@", out, j);
     auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
     auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
-    if (in_var->GetType() != proto::VarType::LOD_TENSOR) {
-      VLOG(3) << "input " << in << " is not LodTensor";
+    if (in_var->GetType() != proto::VarType::LOD_TENSOR &&
+        in_var->GetType() != proto::VarType::LOD_TENSOR_ARRAY) {
+      VLOG(3) << "input " << in << " is not LodTensor or LodTensorArray.";
       return;
     }
     out_var->SetLoDLevel(in_var->GetLoDLevel());
   }
+
+  void DecreaseLoDLevel(const std::string &in, const std::string &out,
+                        size_t i = 0, size_t j = 0) const override {
+    PADDLE_ENFORCE_LT(i, Inputs(in).size());
+    PADDLE_ENFORCE_LT(j, Outputs(out).size());
+    PADDLE_ENFORCE(Inputs(in)[i] != framework::kEmptyVarName,
+                   "The %s[%d] is @EMPTY@", in, i);
+    PADDLE_ENFORCE(Outputs(out)[j] != framework::kEmptyVarName,
+                   "The %s[%d] is @EMPTY@", out, j);
+    auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
+    auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
+    PADDLE_ENFORCE(out_var->GetType() == proto::VarType::LOD_TENSOR_ARRAY ||
+                       out_var->GetType() == proto::VarType::LOD_TENSOR,
+                   "The input %s should be LodTensorArray or LodTensor.",
+                   out_var->Name());
+    PADDLE_ENFORCE(in_var->GetType() == proto::VarType::LOD_TENSOR,
+                   "The input %s should be LodTensor.", in_var->Name());
+    if (in_var->GetLoDLevel() > 0) {
+      out_var->SetLoDLevel(in_var->GetLoDLevel() - 1);
+    }
+  }

   bool IsRuntime() const override;

  protected:
......
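At compile time only the nesting depth of a LoD (level of details) structure is known, not the concrete offsets, so the new hook simply decrements the recorded level. A minimal sketch of the intended call pattern, using a hypothetical op that strips one sequence level (the class name is illustrative, not part of this commit):

  class SplitOuterLevelInferShape : public framework::InferShapeBase {
   public:
    void operator()(framework::InferShapeContext *ctx) const override {
      ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
      if (!ctx->IsRuntime()) {
        // Compile time: advertise "one LoD level fewer" so downstream
        // infer-shape passes see the correct nesting depth.
        ctx->DecreaseLoDLevel("X", /*->*/ "Out");
      }
      // At runtime the kernel computes and sets the actual LoD itself.
    }
  };

The same pattern appears verbatim in the lod_tensor_to_array and shrink_rnn_memory hunks below.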
@@ -623,6 +623,11 @@ class RuntimeInferShapeContext : public InferShapeContext {
     out_tensor->set_layout(in_tensor.layout());
   }

+  void DecreaseLoDLevel(const std::string& in, const std::string& out,
+                        size_t i = 0, size_t j = 0) const override {
+    PADDLE_THROW("DecreaseLoDLevel is only used in compile time.");
+  }
+
   bool IsRuntime() const override { return true; }

  protected:
......
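The runtime counterpart deliberately throws: once the program is executing, every LoDTensor carries its concrete LoD and the kernels copy or rebuild it directly, so there is nothing for this hook to do outside of compilation.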
@@ -62,6 +62,9 @@ class InferShapeContext {
   virtual void ShareLoD(const std::string &in, const std::string &out,
                         size_t i = 0, size_t j = 0) const = 0;

+  virtual void DecreaseLoDLevel(const std::string &in, const std::string &out,
+                                size_t i = 0, size_t j = 0) const = 0;
+
   virtual bool IsRuntime() const = 0;

   std::vector<InferShapeVarPtr> GetInputVarPtrs(const std::string &name);
......
@@ -54,6 +54,9 @@ mkdir -p build
 cd build

 for WITH_STATIC_LIB in ON OFF; do
+  # TODO(Superjomn) reopen this
+  # something wrong with the TensorArray reset.
+  :<<D
   # -----simple_on_word2vec-----
   rm -rf *
   cmake .. -DPADDLE_LIB=${inference_install_dir} \
@@ -74,6 +77,7 @@ for WITH_STATIC_LIB in ON OFF; do
       fi
     done
   fi
+D
   # ---------vis_demo---------
   rm -rf *
   cmake .. -DPADDLE_LIB=${inference_install_dir} \
......
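:<<D ... D is the usual shell idiom for block-commenting: the no-op builtin : consumes everything up to the D delimiter as a heredoc, so the whole simple_on_word2vec build-and-run section is skipped until the TensorArray reset problem mentioned in the TODO is fixed.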
@@ -167,6 +167,19 @@ $$T = A[i]$$
 };

 class ReadFromArrayInferShape : public WriteToArrayInferShape {
+ public:
+  void operator()(framework::InferShapeContext *context) const override {
+    WriteToArrayInferShape::operator()(context);
+    if (!context->HasInput("X")) {
+      return;
+    }
+    // FIXME: just for compile time.
+    if (!context->IsRuntime()) {
+      context->ShareLoD("X", /*->*/ "Out");
+    }
+  }
+
  protected:
   const char *NotHasXError() const override {
     return "The input array X must be set";
......
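This works together with the op_desc.cc change above, which lets the compile-time ShareLoD accept LOD_TENSOR_ARRAY inputs: reading an element back out of the array restores the LoD level that writing into it recorded. At runtime the kernel copies the selected element's actual LoD, hence the IsRuntime() guard and the FIXME.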
@@ -192,6 +192,10 @@ class LoDTensorToArrayInferShape : public framework::InferShapeBase {
     // The first dim of each LoDTensor in Output can only be set at run-time.;
     // We still have to Resize each LoDTensor in Output.
     context->SetOutputDim("Out", x_dim);
+    // The lod level should be passed to out in compile time.
+    if (!context->IsRuntime()) {
+      context->DecreaseLoDLevel("X", /*->*/ "Out");
+    }
   }
 };
......
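lod_tensor_to_array splits X along its outermost sequence level, so each element of the output array has one LoD level fewer than the input; a level-2 input (words nested in sentences, say) yields level-1 elements. DecreaseLoDLevel records exactly that at compile time, while the concrete offsets are still computed only when the op runs.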
@@ -168,6 +168,9 @@ class Blas {
   template <typename T>
   void SCAL(int n, const T a, T* x) const;

+  template <typename T>
+  T ASUM(int n, T* x, int inc) const;
+
   template <typename T>
   void BatchedGEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N,
                    int K, T alpha, const T* A, const T* B, T beta, T* C,
@@ -269,6 +272,11 @@ class BlasT : private Blas<DeviceContext> {
     Base()->template SCAL<T>(args...);
   }

+  template <typename... ARGS>
+  T ASUM(ARGS... args) const {
+    return Base()->template ASUM<T>(args...);
+  }
+
   template <typename... ARGS>
   void BatchedGEMM(ARGS... args) const {
     Base()->template BatchedGEMM<T>(args...);
......
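A minimal usage sketch of the new entry point (dev_ctx, data, and n are assumed names, not part of the commit):

  // Sum of the absolute values of n contiguous floats (stride 1).
  auto blas = math::GetBlas<platform::CPUDeviceContext, float>(dev_ctx);
  float denom = blas.ASUM(n, data, 1);
  blas.SCAL(n, 1.0f / denom, data);  // normalize in place, as softmax does below

BlasT::ASUM forwards to the per-type CBlas specializations added next, so float and double resolve to cblas_sasum and cblas_dasum respectively.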
@@ -84,6 +84,11 @@ struct CBlas<float> {
     platform::dynload::cblas_sscal(args...);
   }

+  template <typename... ARGS>
+  static float ASUM(ARGS... args) {
+    return platform::dynload::cblas_sasum(args...);
+  }
+
   template <typename... ARGS>
   static void GEMM_BATCH(ARGS... args) {
     platform::dynload::cblas_sgemm_batch(args...);
@@ -174,6 +179,11 @@ struct CBlas<double> {
     platform::dynload::cblas_dscal(args...);
   }

+  template <typename... ARGS>
+  static double ASUM(ARGS... args) {
+    return platform::dynload::cblas_dasum(args...);
+  }
+
   template <typename... ARGS>
   static void GEMM_BATCH(ARGS... args) {
     platform::dynload::cblas_dgemm_batch(args...);
@@ -268,6 +278,7 @@ struct CBlas<platform::float16> {
   static void VPOW(...) { PADDLE_THROW("float16 VPOW not supported on CPU"); }
   static void DOT(...) { PADDLE_THROW("float16 DOT not supported on CPU"); };
   static void SCAL(...) { PADDLE_THROW("float16 SCAL not supported on CPU"); };
+  static void ASUM(...) { PADDLE_THROW("float16 ASUM not supported on CPU"); };
 #ifdef PADDLE_WITH_MKLML
   static void GEMM_BATCH(...) {
     PADDLE_THROW("float16 GEMM_BATCH not supported on CPU");
@@ -476,6 +487,21 @@ void Blas<platform::CPUDeviceContext>::SCAL(int n, const T a, T *x) const {
 #endif
 }

+template <>
+template <typename T>
+T Blas<platform::CPUDeviceContext>::ASUM(int n, T *x, int inc) const {
+  auto sum = static_cast<T>(0.0);
+#ifdef PADDLE_WITH_MKLML
+  sum = CBlas<T>::ASUM(n, x, inc);
+#else
+  // TODO(jczaja): check if openblas does provide cblas_sasum/cblas_dasum
+  for (int c = 0; c < n; ++c) {
+    sum += x[c];
+  }
+#endif
+  return sum;
+}
+
 template <>
 template <typename T>
 void Blas<platform::CPUDeviceContext>::GEMV(bool trans_a, int M, int N, T alpha,
......
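Note that the non-MKLML fallback is a plain loop: it sums raw values rather than absolute values and ignores inc. That is harmless for the softmax use below, where the data is post-VEXP (hence positive) and the stride is 1, but it is not a general asum.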
@@ -100,11 +100,8 @@ class SoftmaxFunctor<DeviceContext, float, true, enable_if_CPU<DeviceContext>> {
     blas.VEXP(num_classes * batch_size, out_data, out_data);
     for (int n = 0; n < batch_size; ++n) {
-      entities[n] = out_data[n * num_classes];
-      for (int c = 1; c < num_classes; ++c) {
-        entities[n] += out_data[n * num_classes + c];
-      }
-      blas.SCAL(num_classes, 1.0f / entities[n], &out_data[n * num_classes]);
+      auto sum = blas.ASUM(num_classes, &out_data[n * num_classes], 1);
+      blas.SCAL(num_classes, 1.0f / sum, &out_data[n * num_classes]);
     }
   }
 };
......
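This is the payoff of the new primitive. After VEXP, row n of out_data holds exp(x[c]) for each class c, so softmax reduces to dividing by the row sum, out[c] = exp(x[c]) / sum_j exp(x[j]). One ASUM computes the denominator and one SCAL applies it, replacing the hand-written accumulation loop and the entities scratch array it needed.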
@@ -201,6 +201,9 @@ class IdentityInferShape : public framework::InferShapeBase {
  public:
   void operator()(framework::InferShapeContext *context) const override {
     context->SetOutputDim("Out", context->GetInputDim("X"));
+    if (!context->IsRuntime()) {
+      context->ShareLoD("X", /*->*/ "Out");
+    }
   }
 };
......
@@ -100,6 +100,9 @@ class ShrinkRNNMemoryInferShape : public framework::InferShapeBase {
     PADDLE_ENFORCE(context->HasInput("I"));
     PADDLE_ENFORCE(context->HasInput("RankTable"));
     context->SetOutputDim("Out", context->GetInputDim("X"));
+    if (!context->IsRuntime()) {
+      context->DecreaseLoDLevel("X", /*->*/ "Out");
+    }
   }
 };
......
@@ -36,9 +36,7 @@ class SoftmaxKernel : public framework::OpKernel<T> {
     Tensor Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);

 #ifdef PADDLE_ON_INFERENCE
-    math::SoftmaxFunctor<
-        DeviceContext, T,
-        std::is_same<DeviceContext, platform::CPUDeviceContext>::value>()(
+    math::SoftmaxFunctor<DeviceContext, T, true>()(
         context.template device_context<DeviceContext>(), &X_2d, &Out_2d);
 #else
     math::SoftmaxFunctor<DeviceContext, T, false>()(
......
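The call site no longer computes std::is_same<DeviceContext, platform::CPUDeviceContext>::value itself: the third template argument now just flags inference mode, and selecting the CPU fast path is the functor's job via the enable_if_CPU<DeviceContext> partial specialization seen in the softmax hunk above.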
@@ -68,6 +68,8 @@ extern void* mklml_dso_handle;
   __macro(cblas_dgemm_batch);  \
   __macro(cblas_sdot);         \
   __macro(cblas_ddot);         \
+  __macro(cblas_sasum);        \
+  __macro(cblas_dasum);        \
   __macro(cblas_sscal);        \
   __macro(cblas_dscal);        \
   __macro(vsAdd);              \
......
@@ -398,7 +398,26 @@ All parameter, weight, gradient are variables in Paddle.
        },
        py::return_value_policy::copy);

-  py::class_<Scope>(m, "Scope", "")
+  py::class_<Scope>(m, "Scope", R"DOC(
+    Scope is an association of a name to Variable. All variables belong to Scope.
+
+    Variables in a parent scope can be retrieved from a local scope.
+
+    You need to specify a scope to run a Net, i.e., `exe.Run(&scope)`.
+    One net can run in different scopes and update different variables in the
+    scope.
+
+    You can create a variable in a scope and get it from the scope.
+
+    Examples:
+        .. code-block:: python
+
+          # create a tensor from a scope and set a value to it.
+          param = scope.var('Param').get_tensor()
+          param_array = np.full((height, row_numel), 5.0).astype("float32")
+          param.set(param_array, place)
+
+        )DOC")
       .def("var",
            [](Scope &self, const std::string &name) -> Variable * {
              return self.Var(name);
......
@@ -71,15 +71,16 @@ def __build_dict(tar_file, dict_size, save_path, lang):
             for w in sen.split():
                 word_dict[w] += 1

-    with open(save_path, "w") as fout:
-        fout.write("%s\n%s\n%s\n" % (START_MARK, END_MARK, UNK_MARK))
+    with open(save_path, "wb") as fout:
+        fout.write(
+            cpt.to_bytes("%s\n%s\n%s\n" % (START_MARK, END_MARK, UNK_MARK)))
         for idx, word in enumerate(
                 sorted(
                     six.iteritems(word_dict), key=lambda x: x[1],
                     reverse=True)):
             if idx + 3 == dict_size: break
-            fout.write(word[0].encode('utf-8'))
-            fout.write('\n')
+            fout.write(cpt.to_bytes(word[0]))
+            fout.write(cpt.to_bytes('\n'))


 def __load_dict(tar_file, dict_size, lang, reverse=False):
......
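The old code wrote encode('utf-8') bytes into a text-mode file, which raises TypeError on Python 3. Opening the dictionary in binary mode and funneling every write through cpt.to_bytes makes __build_dict behave identically on Python 2 and 3.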
@@ -33,13 +33,15 @@ def force_init_on_cpu():
     """
    The flag of whether to force initializing variables on CPU.

-    Returns::
+    Returns:
+        bool: whether we should force init on CPU.

     Examples:
         .. code-block:: python

            if force_init_on_cpu():
-                pass
+                create_op(force_cpu=force_init_on_cpu())

     """
     return _force_init_on_cpu_
......
@@ -10,6 +10,8 @@ else()
     foreach(src ${TEST_OPS})
         if(${src} STREQUAL "test_recognize_digits_conv")
             message(WARNING "These tests has been disabled in OSX for random fail: \n" ${src})
+        elseif(${src} STREQUAL "test_recognize_digits_mlp")
+            message(WARNING "These tests has been disabled in OSX for random fail: \n" ${src})
         else()
             py_test(${src} SRCS ${src}.py)
         endif()
......
@@ -172,6 +172,7 @@ class TestDynRNN(unittest.TestCase):
         rnn = fluid.layers.DynamicRNN()
         with rnn.block():
             in_ = rnn.step_input(sentence)
+            assert in_.lod_level == 1, "the lod level of in_ should be 1"
             sent_emb = fluid.layers.embedding(
                 input=in_, size=[len(word_dict), 32], dtype='float32')
             out_ = fluid.layers.fc(input=sent_emb, size=100, act='tanh')
@@ -179,6 +180,7 @@ class TestDynRNN(unittest.TestCase):
         rnn1 = fluid.layers.DynamicRNN()
         with rnn1.block():
             in_1 = rnn1.step_input(out_)
+            assert in_1.lod_level == 0, "the lod level of in_1 should be 0"
             out_1 = fluid.layers.fc(input=[in_1], size=100, act='tanh')
             rnn1.output(out_1)
......
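The two new asserts pin down the compile-time behaviour this commit adds: each step_input peels one sequence level off its input, so in_ sits at LoD level 1 and the nested in_1 at level 0. Before the DecreaseLoDLevel plumbing above, these levels were not visible to the compiler and the asserts could not hold.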