diff --git a/paddle/capi/Main.cpp b/paddle/capi/Main.cpp index bb8249a5511c089ec2f2263ff4cc290f0a5a8fce..c038789340033fcf6dcc07a41b033a50e980c965 100644 --- a/paddle/capi/Main.cpp +++ b/paddle/capi/Main.cpp @@ -43,4 +43,11 @@ paddle_error paddle_init(int argc, char** argv) { isInit = true; return kPD_NO_ERROR; } + +paddle_error paddle_init_thread() { + if (FLAGS_use_gpu) { + hl_init(FLAGS_gpu_id); + } + return kPD_NO_ERROR; +} } diff --git a/paddle/capi/Matrix.cpp b/paddle/capi/Matrix.cpp index 30f3a766f0c65187c8f2dd4603e3d26c9b9a6a3d..cbacd1fb71c14f490ff548db714e728772292b4b 100644 --- a/paddle/capi/Matrix.cpp +++ b/paddle/capi/Matrix.cpp @@ -40,7 +40,7 @@ paddle_error paddle_matrix_destroy(paddle_matrix mat) { paddle_error paddle_matrix_set_row(paddle_matrix mat, uint64_t rowID, paddle_real* rowArray) { - if (mat == nullptr) return kPD_NULLPTR; + if (mat == nullptr || rowArray == nullptr) return kPD_NULLPTR; auto ptr = cast(mat); if (ptr->mat == nullptr) return kPD_NULLPTR; if (rowID >= ptr->mat->getHeight()) return kPD_OUT_OF_RANGE; diff --git a/paddle/capi/error.cpp b/paddle/capi/error.cpp new file mode 100644 index 0000000000000000000000000000000000000000..169b65f92104336d9ec12e2a5a6778db25080270 --- /dev/null +++ b/paddle/capi/error.cpp @@ -0,0 +1,32 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "error.h" + +const char* paddle_error_string(paddle_error err) { + switch (err) { + case kPD_NULLPTR: + return "nullptr error"; + case kPD_OUT_OF_RANGE: + return "out of range error"; + case kPD_PROTOBUF_ERROR: + return "protobuf error"; + case kPD_NOT_SUPPORTED: + return "not supported error"; + case kPD_UNDEFINED_ERROR: + return "undefined error"; + default: + return ""; + } +} diff --git a/paddle/capi/error.h b/paddle/capi/error.h index 44d8c2040d1aad698398089baeee6f13c3deeb55..9d9d0ed63a5276c6b9a8747e1ee1fce6872bdc9e 100644 --- a/paddle/capi/error.h +++ b/paddle/capi/error.h @@ -15,6 +15,8 @@ limitations under the License. */ #ifndef __PADDLE_CAPI_ERROR_H__ #define __PADDLE_CAPI_ERROR_H__ +#include "config.h" + /** * Error Type for Paddle API. */ @@ -27,4 +29,9 @@ typedef enum { kPD_UNDEFINED_ERROR = -1, } paddle_error; +/** + * Error string for Paddle API. + */ +PD_API const char* paddle_error_string(paddle_error err); + #endif diff --git a/paddle/capi/examples/model_inference/multi_thread/CMakeLists.txt b/paddle/capi/examples/model_inference/multi_thread/CMakeLists.txt index 98e411ddc02a46034e8f6ceb00657622d998c9f3..2fc8debddedeab6ae982b0df49ec2b73bc0f85f5 100644 --- a/paddle/capi/examples/model_inference/multi_thread/CMakeLists.txt +++ b/paddle/capi/examples/model_inference/multi_thread/CMakeLists.txt @@ -1,8 +1,29 @@ project(multi_thread) cmake_minimum_required(VERSION 2.8) -aux_source_directory(. SRC_LIST) -add_executable(${PROJECT_NAME} ${SRC_LIST}) + find_package (Threads) + +if(NOT PADDLE_ROOT) + set(PADDLE_ROOT $ENV{PADDLE_ROOT} CACHE PATH "Paddle Path") +endif() +if(PADDLE_ROOT) + include_directories(${PADDLE_ROOT}/include) + link_directories(${PADDLE_ROOT}/lib) +endif() + +set(CPU_SRCS main.c) +add_executable(${PROJECT_NAME} ${CPU_SRCS}) set_property(TARGET ${PROJECT_NAME} PROPERTY C_STANDARD 99) -target_link_libraries(${PROJECT_NAME} -lpaddle_capi_shared - ${CMAKE_THREAD_LIBS_INIT}) +target_link_libraries(${PROJECT_NAME} + -lpaddle_capi_shared + ${CMAKE_THREAD_LIBS_INIT}) + +find_package(CUDA QUIET) +if(CUDA_FOUND) + set(GPU_SRCS main_gpu.c) + cuda_add_executable(${PROJECT_NAME}_gpu ${GPU_SRCS}) + set_property(TARGET ${PROJECT_NAME}_gpu PROPERTY C_STANDARD 99) + target_link_libraries(${PROJECT_NAME}_gpu + -lpaddle_capi_shared + ${CMAKE_THREAD_LIBS_INIT}) +endif(CUDA_FOUND) diff --git a/paddle/capi/examples/model_inference/multi_thread/main_gpu.c b/paddle/capi/examples/model_inference/multi_thread/main_gpu.c new file mode 100644 index 0000000000000000000000000000000000000000..6fd376e0d1a2fee4f9a0f676b53c6f2891795cab --- /dev/null +++ b/paddle/capi/examples/model_inference/multi_thread/main_gpu.c @@ -0,0 +1,113 @@ +#include +#include +#include +#include "../common/common.h" + +#define CONFIG_BIN "./trainer_config.bin" +#define NUM_THREAD 4 +#define NUM_ITER 1000 + +pthread_mutex_t mutex; + +/* + * @brief It is an simple inference example that runs multi-threads on a GPU. + * Each thread holds it own local gradient_machine but shares the same + * parameters. + * If you want to run on different GPUs, you need to launch + * multi-processes or set trainer_count > 1. + */ +void* thread_main(void* gm_ptr) { + // Initialize the thread environment of Paddle. + CHECK(paddle_init_thread()); + + paddle_gradient_machine machine = (paddle_gradient_machine)(gm_ptr); + // Create input arguments. + paddle_arguments in_args = paddle_arguments_create_none(); + // Create input matrix. + paddle_matrix mat = paddle_matrix_create(/* sample_num */ 1, + /* size */ 784, + /* useGPU */ true); + // Create output arguments. + paddle_arguments out_args = paddle_arguments_create_none(); + // Create output matrix. + paddle_matrix prob = paddle_matrix_create_none(); + + // CPU buffer to cache the input and output. + paddle_real* cpu_input = (paddle_real*)malloc(784 * sizeof(paddle_real)); + paddle_real* cpu_output = (paddle_real*)malloc(10 * sizeof(paddle_real)); + for (int iter = 0; iter < NUM_ITER; ++iter) { + // There is only one input layer of this network. + CHECK(paddle_arguments_resize(in_args, 1)); + CHECK(paddle_arguments_set_value(in_args, 0, mat)); + + for (int i = 0; i < 784; ++i) { + cpu_input[i] = rand() / ((float)RAND_MAX); + } + CHECK(paddle_matrix_set_value(mat, cpu_input)); + + CHECK(paddle_gradient_machine_forward(machine, + in_args, + out_args, + /* isTrain */ false)); + + CHECK(paddle_arguments_get_value(out_args, 0, prob)); + CHECK(paddle_matrix_get_value(prob, cpu_output)); + + pthread_mutex_lock(&mutex); + printf("Prob: "); + for (int i = 0; i < 10; ++i) { + printf("%.2f ", cpu_output[i]); + } + printf("\n"); + pthread_mutex_unlock(&mutex); + } + + CHECK(paddle_matrix_destroy(prob)); + CHECK(paddle_arguments_destroy(out_args)); + CHECK(paddle_matrix_destroy(mat)); + CHECK(paddle_arguments_destroy(in_args)); + CHECK(paddle_gradient_machine_destroy(machine)); + + free(cpu_input); + free(cpu_output); + + return NULL; +} + +int main() { + // Initalize Paddle + char* argv[] = {"--use_gpu=True"}; + CHECK(paddle_init(1, (char**)argv)); + + // Reading config binary file. It is generated by `convert_protobin.sh` + long size; + void* buf = read_config(CONFIG_BIN, &size); + + // Create a gradient machine for inference. + paddle_gradient_machine machine; + CHECK(paddle_gradient_machine_create_for_inference(&machine, buf, (int)size)); + CHECK(paddle_gradient_machine_randomize_param(machine)); + + // Loading parameter. Uncomment the following line and change the directory. + // CHECK(paddle_gradient_machine_load_parameter_from_disk(machine, + // "./some_where_to_params")); + srand(time(0)); + pthread_mutex_init(&mutex, NULL); + + pthread_t threads[NUM_THREAD]; + + for (int i = 0; i < NUM_THREAD; ++i) { + paddle_gradient_machine thread_local_machine; + CHECK(paddle_gradient_machine_create_shared_param( + machine, buf, size, &thread_local_machine)); + pthread_create(&threads[i], NULL, thread_main, thread_local_machine); + } + + for (int i = 0; i < NUM_THREAD; ++i) { + pthread_join(threads[i], NULL); + } + + pthread_mutex_destroy(&mutex); + + return 0; +} diff --git a/paddle/capi/main.h b/paddle/capi/main.h index 893ebcbd58dd24cf835fb2005865c94c9ba2a810..99c4e8428dbaa14d36dc2d36b2a4f16c9ec3e0d1 100644 --- a/paddle/capi/main.h +++ b/paddle/capi/main.h @@ -26,6 +26,13 @@ extern "C" { */ PD_API paddle_error paddle_init(int argc, char** argv); +/** + * Initialize the thread environment of Paddle. + * @note it is requisite for GPU runs but optional for CPU runs. + * For GPU runs, all threads will run on the same GPU devices. + */ +PD_API paddle_error paddle_init_thread(); + #ifdef __cplusplus } #endif diff --git a/paddle/framework/op_desc.cc b/paddle/framework/op_desc.cc index 2281d93df90d0829cede63eaaefade5d5bac776b..cde3f1ac2e411a79f948e0c15a90ec8278a93a29 100644 --- a/paddle/framework/op_desc.cc +++ b/paddle/framework/op_desc.cc @@ -59,7 +59,7 @@ class CompileTimeInferShapeContext : public InferShapeContext { auto *in_var = block_.FindVarRecursive(Inputs(in)[i]); auto *out_var = block_.FindVarRecursive(Outputs(out)[j]); if (in_var->GetType() != VarDesc::LOD_TENSOR) { - VLOG(3) << "input " << in << "is not LodTensor"; + VLOG(3) << "input " << in << " is not LodTensor"; return; } PADDLE_ENFORCE_EQ(in_var->GetType(), VarDesc::LOD_TENSOR, diff --git a/paddle/operators/concat_op.cc b/paddle/operators/concat_op.cc index 6134ac78b145e0c9db0146a38f525204d9f11fed..cf522d6921ee746d03d8082b8fc4d051f4d504e6 100644 --- a/paddle/operators/concat_op.cc +++ b/paddle/operators/concat_op.cc @@ -41,14 +41,18 @@ class ConcatOp : public framework::OperatorWithKernel { for (size_t j = 0; j < in_zero_dims_size; j++) { if (j == axis) { out_dims[axis] += ins[i][j]; - continue; + } else { + PADDLE_ENFORCE_EQ(out_dims[j], ins[i][j], + "Input tensors should have the same " + "elements except the specify axis."); } - PADDLE_ENFORCE_EQ(out_dims[j], ins[i][j], - "Input tensors should have the same " - "elements except the specify axis."); } } + if (out_dims[axis] < 0) { + out_dims[axis] = -1; + } ctx->SetOutputDim("Out", out_dims); + ctx->ShareLoD("X", /*->*/ "Out"); } }; diff --git a/paddle/operators/cross_entropy_op.cc b/paddle/operators/cross_entropy_op.cc index 1e82742eaf86711fe4f9d02d517ad1853131cf67..2b06012b690c6725fd150cd99e992912655dc9c6 100644 --- a/paddle/operators/cross_entropy_op.cc +++ b/paddle/operators/cross_entropy_op.cc @@ -95,6 +95,7 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel { "Input(Label) should be 1."); } ctx->SetOutputDim(framework::GradVarName("X"), x_dims); + ctx->ShareLoD("X", framework::GradVarName("X")); } protected: diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc index 7afcdfce9371e29aad968a1729931173fb2309b5..ae4f0bf896dce013d301aa0bf9f732f0fd9cc6bf 100644 --- a/paddle/platform/device_context.cc +++ b/paddle/platform/device_context.cc @@ -122,10 +122,6 @@ Place CUDADeviceContext::GetPlace() const { return place_; } void CUDADeviceContext::Wait() const { PADDLE_ENFORCE(cudaStreamSynchronize(stream_)); -} - -void CUDADeviceContext::Finish() const { - Wait(); PADDLE_ENFORCE(cudaGetLastError()); } diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index 526d089e35da9c9f89a3852095ad3a4c82d4d85d..ef5f19214d9ccb23b9c946bee28cb764122bd7cd 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -46,8 +46,6 @@ class DeviceContext { DeviceType* GetEigenDevice() const; virtual void Wait() const {} - - virtual void Finish() const {} }; class CPUDeviceContext : public DeviceContext { @@ -79,9 +77,6 @@ class CUDADeviceContext : public DeviceContext { /*! \brief Wait for all operations completion in the stream. */ void Wait() const override; - /*! \brief Check potential errors for the cuda kernel calls. */ - void Finish() const override; - /*! \brief Return place in the device context. */ Place GetPlace() const override; diff --git a/paddle/scripts/docker/build.sh b/paddle/scripts/docker/build.sh index 3c6ec6fabafb84add4711f0e967028a02058218a..e43b9c218a3ecb9e7f20fb7e8b14a85a29947eef 100644 --- a/paddle/scripts/docker/build.sh +++ b/paddle/scripts/docker/build.sh @@ -113,7 +113,10 @@ EOF -DWITH_SWIG_PY=ON \ -DWITH_STYLE_CHECK=OFF make -j `nproc` gen_proto_py + make -j `nproc` paddle_python make -j `nproc` paddle_docs paddle_docs_cn + make -j `nproc` print_operators_doc + paddle/pybind/print_operators_doc > doc/en/html/operators.json popd fi @@ -185,14 +188,6 @@ EOF ${DOCKERFILE_GPU_ENV} ADD go/cmd/pserver/pserver /usr/bin/ ADD go/cmd/master/master /usr/bin/ -EOF - - if [[ ${WITH_DOC:-OFF} == 'ON' ]]; then - cat >> /paddle/build/Dockerfile <> /paddle/build/Dockerfile <`_ The example usage is: @@ -3098,7 +3098,7 @@ def img_cmrnorm_layer(input, Reference: `ImageNet Classification with Deep Convolutional Neural Networks - http://www.cs.toronto.edu/~fritz/absps/imagenet.pdf`_ + `_ The example usage is: @@ -3166,7 +3166,7 @@ def batch_norm_layer(input, Reference: `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift - http://arxiv.org/abs/1502.03167`_ + `_ The example usage is: @@ -5424,17 +5424,19 @@ def maxout_layer(input, groups, num_channels=None, name=None, layer_attr=None): Reference: `Maxout Networks - http://www.jmlr.org/proceedings/papers/v28/goodfellow13.pdf`_ + `_ `Multi-digit Number Recognition from Street View Imagery using Deep Convolutional Neural Networks - https://arxiv.org/pdf/1312.6082v4.pdf`_ + `_ + .. math:: - y_{si+j} = \max_k x_{gsi + sk + j} - g = groups - s = input.size / num_channels - 0 \le i < num_channels / groups - 0 \le j < s - 0 \le k < groups + out = \max_k (in[n, k, o_c , s]) \\\\ + out_{i * s + j} = \max_k in_{ k * o_{c} * s + i * s + j} \\\\ + s = \frac{input.size}{ num\_channels} \\\\ + o_{c} =\frac{num\_channels}{groups} \\\\ + 0 \le i < o_{c} \\\\ + 0 \le j < s \\\\ + 0 \le k < groups \\\\ The simple usage is: @@ -5493,7 +5495,7 @@ def ctc_layer(input, Reference: `Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neural Networks - http://machinelearning.wustl.edu/mlpapers/paper_files/icml2006_GravesFGS06.pdf`_ + `_ Note: Considering the 'blank' label needed by CTC, you need to use (num_classes + 1) @@ -5567,7 +5569,7 @@ def warp_ctc_layer(input, Reference: `Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neural Networks - http://machinelearning.wustl.edu/mlpapers/paper_files/icml2006_GravesFGS06.pdf`_ + `_ Note: - Let num_classes represents the category number. Considering the 'blank' @@ -5788,7 +5790,7 @@ def nce_layer(input, Reference: `A fast and simple algorithm for training neural probabilistic language - models. https://www.cs.toronto.edu/~amnih/papers/ncelm.pdf`_ + models. `_ The example usage is: @@ -5904,7 +5906,7 @@ def rank_cost(left, Reference: `Learning to Rank using Gradient Descent - http://research.microsoft.com/en-us/um/people/cburges/papers/ICML_ranking.pdf`_ + `_ .. math:: @@ -6440,7 +6442,7 @@ def smooth_l1_cost(input, label, name=None, coeff=1.0, layer_attr=None): Reference: `Fast R-CNN - https://arxiv.org/pdf/1504.08083v2.pdf`_ + `_ The example usage is: @@ -6647,7 +6649,7 @@ def prelu_layer(input, Reference: `Delving Deep into Rectifiers: Surpassing Human-Level Performance on - ImageNet Classification http://arxiv.org/pdf/1502.01852v1.pdf`_ + ImageNet Classification `_ .. math:: z_i &\\quad if \\quad z_i > 0 \\\\ @@ -6744,7 +6746,7 @@ def gated_unit_layer(input, Reference: `Language Modeling with Gated Convolutional Networks - https://arxiv.org/abs/1612.08083`_ + `_ .. math:: y=\\text{act}(X \cdot W + b)\otimes \sigma(X \cdot V + c) diff --git a/python/paddle/v2/fluid/layers.py b/python/paddle/v2/fluid/layers.py index fb444f2d8698eb355f74188ccf6a6516c06c1086..b4426bad1499419a6b512aa32abfed4fc21ef4c5 100644 --- a/python/paddle/v2/fluid/layers.py +++ b/python/paddle/v2/fluid/layers.py @@ -430,7 +430,8 @@ def _create_op_func_(op_type): dtype = each.dtype elif dtype != each.dtype: raise ValueError( - "operator {0} must input same dtype".format(op_type)) + "operator {0} must input same dtype. {1} vs {2}".format( + op_type, dtype, each.dtype)) return dtype diff --git a/python/paddle/v2/fluid/tests/book/test_machine_translation.py b/python/paddle/v2/fluid/tests/book/test_machine_translation.py index 5bc7e1b59d9e7ae7932c58c3dc938148adf52c78..80ffc5a544c201ed45a6de46b5a2addff82246b7 100644 --- a/python/paddle/v2/fluid/tests/book/test_machine_translation.py +++ b/python/paddle/v2/fluid/tests/book/test_machine_translation.py @@ -1,59 +1,62 @@ import numpy as np import paddle.v2 as paddle -import paddle.v2.dataset.conll05 as conll05 +import paddle.v2.fluid as fluid import paddle.v2.fluid.core as core import paddle.v2.fluid.framework as framework import paddle.v2.fluid.layers as layers -from paddle.v2.fluid.executor import Executor, g_scope -from paddle.v2.fluid.optimizer import SGDOptimizer -import paddle.v2.fluid as fluid -import paddle.v2.fluid.layers as pd +from paddle.v2.fluid.executor import Executor dict_size = 30000 source_dict_dim = target_dict_dim = dict_size src_dict, trg_dict = paddle.dataset.wmt14.get_dict(dict_size) -hidden_dim = 512 -word_dim = 512 +hidden_dim = 32 +word_dim = 16 IS_SPARSE = True -batch_size = 50 +batch_size = 10 max_length = 50 topk_size = 50 trg_dic_size = 10000 -src_word_id = layers.data(name="src_word_id", shape=[1], dtype='int64') -src_embedding = layers.embedding( - input=src_word_id, - size=[dict_size, word_dim], - dtype='float32', - is_sparse=IS_SPARSE, - param_attr=fluid.ParamAttr(name='vemb')) - - -def encoder(): - - lstm_hidden0, lstm_0 = layers.dynamic_lstm( - input=src_embedding, - size=hidden_dim, - candidate_activation='sigmoid', - cell_activation='sigmoid') - - lstm_hidden1, lstm_1 = layers.dynamic_lstm( - input=src_embedding, - size=hidden_dim, - candidate_activation='sigmoid', - cell_activation='sigmoid', - is_reverse=True) - - bidirect_lstm_out = layers.concat([lstm_hidden0, lstm_hidden1], axis=0) - - return bidirect_lstm_out - - -def decoder_trainer(context): - ''' - decoder with trainer - ''' - pass +decoder_size = hidden_dim + + +def encoder_decoder(): + # encoder + src_word_id = layers.data( + name="src_word_id", shape=[1], dtype='int64', lod_level=1) + src_embedding = layers.embedding( + input=src_word_id, + size=[dict_size, word_dim], + dtype='float32', + is_sparse=IS_SPARSE, + param_attr=fluid.ParamAttr(name='vemb')) + + fc1 = fluid.layers.fc(input=src_embedding, size=hidden_dim * 4, act='tanh') + lstm_hidden0, lstm_0 = layers.dynamic_lstm(input=fc1, size=hidden_dim * 4) + encoder_out = layers.sequence_pool(input=lstm_hidden0, pool_type="last") + + # decoder + trg_language_word = layers.data( + name="target_language_word", shape=[1], dtype='int64', lod_level=1) + trg_embedding = layers.embedding( + input=trg_language_word, + size=[dict_size, word_dim], + dtype='float32', + is_sparse=IS_SPARSE, + param_attr=fluid.ParamAttr(name='vemb')) + + rnn = fluid.layers.DynamicRNN() + with rnn.block(): + current_word = rnn.step_input(trg_embedding) + mem = rnn.memory(init=encoder_out) + fc1 = fluid.layers.fc(input=[current_word, mem], + size=decoder_size, + act='tanh') + out = fluid.layers.fc(input=fc1, size=target_dict_dim, act='softmax') + rnn.update_memory(mem, fc1) + rnn.output(out) + + return rnn() def to_lodtensor(data, place): @@ -72,13 +75,18 @@ def to_lodtensor(data, place): def main(): - encoder_out = encoder() - # TODO(jacquesqiao) call here - decoder_trainer(encoder_out) + rnn_out = encoder_decoder() + label = layers.data( + name="target_language_next_word", shape=[1], dtype='int64', lod_level=1) + cost = layers.cross_entropy(input=rnn_out, label=label) + avg_cost = fluid.layers.mean(x=cost) + + optimizer = fluid.optimizer.Adagrad(learning_rate=1e-4) + optimizer.minimize(avg_cost) train_data = paddle.batch( paddle.reader.shuffle( - paddle.dataset.wmt14.train(8000), buf_size=1000), + paddle.dataset.wmt14.train(dict_size), buf_size=1000), batch_size=batch_size) place = core.CPUPlace() @@ -88,15 +96,23 @@ def main(): batch_id = 0 for pass_id in xrange(2): - print 'pass_id', pass_id for data in train_data(): - print 'batch', batch_id - batch_id += 1 - if batch_id > 10: break word_data = to_lodtensor(map(lambda x: x[0], data), place) + trg_word = to_lodtensor(map(lambda x: x[1], data), place) + trg_word_next = to_lodtensor(map(lambda x: x[2], data), place) outs = exe.run(framework.default_main_program(), - feed={'src_word_id': word_data, }, - fetch_list=[encoder_out]) + feed={ + 'src_word_id': word_data, + 'target_language_word': trg_word, + 'target_language_next_word': trg_word_next + }, + fetch_list=[avg_cost]) + avg_cost_val = np.array(outs[0]) + print('pass_id=' + str(pass_id) + ' batch=' + str(batch_id) + + " avg_cost=" + str(avg_cost_val)) + if batch_id > 3: + exit(0) + batch_id += 1 if __name__ == '__main__': diff --git a/python/setup.py.in b/python/setup.py.in index d59a6a47800291ed744ca225cd765fe6cd207eab..9ccb4dc1762ac761212347fa7c7c94b223d75e24 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -5,7 +5,7 @@ class BinaryDistribution(Distribution): return True MAJOR = 0 -MINOR = 10 +MINOR = 11 PATCH = 0 RC = 0 ISTAGED = False @@ -89,7 +89,7 @@ paddle_rt_libs = ['${WARPCTC_LIBRARIES}'] if '${MKL_SHARED_LIBS}'!= '': paddle_rt_libs += '${MKL_SHARED_LIBS}'.split(';') -setup(name='paddlepaddle', +setup(name='${PACKAGE_NAME}', version='${PADDLE_VERSION}', description='Parallel Distributed Deep Learning', install_requires=setup_requires,