提交 7eb65b31 编写于 作者: D dangqingqing

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into avx_cmake

......@@ -140,19 +140,9 @@ Similarly, the lengths in the top level LoD
are transformed into offsets of elements/words as follows:
```
0 9 10 15
= = =
3+2+4 1+9 2+3+10
```
so we can tell that the first article is from word 0 to word 9, and the second article is from word 9 to word 10.
The complete offset representation is as follows:
```
0 9 10 15
0 3 5 9 10 12 15
||| || |||| | || |||
0 3 4 6
= = =
3 3+1 4+2
```
## Slicing of LoD Tensors
......
......@@ -67,8 +67,11 @@ class CompileTimeInferShapeContext : public InferShapeContext {
out);
in_var->SetLoDLevel(out_var->GetLodLevel());
}
bool IsRuntime() const override;
protected:
VarDesc::VarType GetVarType(const std::string &name) const override;
private:
DDim GetDim(const std::string &name) const override;
void SetDim(const std::string &name, const DDim &dim) override;
......@@ -349,6 +352,9 @@ void OpDescBind::InferVarType(BlockDescBind *block) const {
info.infer_var_type_(*this, block);
} else {
// all output type is LoDTensor by default
VLOG(10) << this->Type()
<< " has not registered InferVarType. Set output variables to "
"LOD_TENSOR";
for (auto &out_pair : this->outputs_) {
for (auto &out_var_name : out_pair.second) {
block->Var(out_var_name)->SetType(VarDesc::LOD_TENSOR);
......@@ -448,6 +454,12 @@ void CompileTimeInferShapeContext::SetDim(const std::string &name,
const DDim &dim) {
block_.FindVarRecursive(name)->SetShape(framework::vectorize(dim));
}
bool CompileTimeInferShapeContext::IsRuntime() const { return false; }
VarDesc::VarType CompileTimeInferShapeContext::GetVarType(
const std::string &name) const {
return block_.FindVarRecursive(name)->GetType();
}
} // namespace framework
} // namespace paddle
......@@ -15,7 +15,9 @@ limitations under the License. */
#include "paddle/framework/operator.h"
#include <algorithm>
#include <atomic>
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/shape_inference.h"
#include "paddle/framework/var_type.h"
namespace paddle {
namespace framework {
......@@ -365,7 +367,9 @@ class RuntimeInferShapeContext : public InferShapeContext {
out_tensor->set_lod(in_tensor.lod());
}
private:
bool IsRuntime() const override { return true; }
protected:
DDim GetDim(const std::string& name) const override {
Variable* var = scope_.FindVar(name);
if (var->IsType<LoDTensor>()) {
......@@ -388,6 +392,12 @@ class RuntimeInferShapeContext : public InferShapeContext {
}
}
VarDesc::VarType GetVarType(const std::string& name) const override {
auto* var = scope_.FindVar(name);
return ToVarType(var->Type());
}
private:
const OperatorBase& op_;
const Scope& scope_;
};
......
......@@ -298,11 +298,10 @@ class ExecutionContext {
}
#ifdef PADDLE_WITH_CUDA
const platform::CUDADeviceContext& cuda_device_context() const {
const inline platform::CUDADeviceContext& cuda_device_context() const {
PADDLE_ENFORCE(platform::is_gpu_place(device_context_.GetPlace()));
auto cuda_ctx =
reinterpret_cast<const platform::CUDADeviceContext*>(&device_context_);
return *cuda_ctx;
return *reinterpret_cast<const platform::CUDADeviceContext*>(
&device_context_);
}
#endif
......
......@@ -46,6 +46,23 @@ void InferShapeContext::SetDims(const std::vector<std::string> &names,
SetDim(names[i], dims[i]);
}
}
std::vector<VarDesc::VarType> InferShapeContext::GetInputsVarType(
const std::string &name) const {
return GetVarTypes(Inputs(name));
}
std::vector<VarDesc::VarType> InferShapeContext::GetOutputsVarType(
const std::string &name) const {
return GetVarTypes(Outputs(name));
}
std::vector<VarDesc::VarType> InferShapeContext::GetVarTypes(
const std::vector<std::string> &names) const {
std::vector<VarDesc::VarType> retv;
retv.resize(names.size());
std::transform(names.begin(), names.end(), retv.begin(),
std::bind(std::mem_fn(&InferShapeContext::GetVarType), this,
std::placeholders::_1));
return retv;
}
} // namespace framework
} // namespace paddle
......@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/framework/attribute.h"
#include "paddle/framework/ddim.h"
#include "paddle/framework/framework.pb.h"
namespace paddle {
namespace framework {
......@@ -26,6 +27,10 @@ class InferShapeContext {
virtual bool HasInput(const std::string &name) const = 0;
virtual bool HasOutput(const std::string &name) const = 0;
std::vector<VarDesc::VarType> GetInputsVarType(const std::string &name) const;
std::vector<VarDesc::VarType> GetOutputsVarType(
const std::string &name) const;
virtual bool HasInputs(const std::string &name) const = 0;
virtual bool HasOutputs(const std::string &name) const = 0;
......@@ -46,6 +51,8 @@ class InferShapeContext {
virtual void ShareLoD(const std::string &in, const std::string &out,
size_t i = 0, size_t j = 0) const = 0;
virtual bool IsRuntime() const = 0;
protected:
virtual framework::DDim GetDim(const std::string &name) const = 0;
virtual void SetDim(const std::string &name, const framework::DDim &dim) = 0;
......@@ -55,6 +62,11 @@ class InferShapeContext {
void SetDims(const std::vector<std::string> &names,
const std::vector<framework::DDim> &dims);
std::vector<VarDesc::VarType> GetVarTypes(
const std::vector<std::string> &names) const;
virtual VarDesc::VarType GetVarType(const std::string &name) const = 0;
};
} // namespace framework
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/lod_rank_table.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/lod_tensor_array.h"
namespace paddle {
namespace framework {
inline VarDesc::VarType ToVarType(std::type_index type) {
if (type.hash_code() == typeid(LoDTensor).hash_code()) {
return VarDesc_VarType_LOD_TENSOR;
} else if (type.hash_code() == typeid(LoDRankTable).hash_code()) {
return VarDesc_VarType_LOD_RANK_TABLE;
} else if (type.hash_code() == typeid(LoDTensorArray).hash_code()) {
return VarDesc_VarType_LOD_TENSOR_ARRAY;
} else {
PADDLE_THROW("ToVarType:Unsupported type %s", type.name());
}
}
} // namespace framework
} // namespace paddle
......@@ -48,6 +48,11 @@ class Variable {
void Clear() { holder_.reset(); }
std::type_index Type() const {
PADDLE_ENFORCE(holder_ != nullptr, "Must hold memory");
return holder_->Type();
}
private:
struct Placeholder {
virtual ~Placeholder() {}
......
......@@ -17,7 +17,7 @@ limitations under the License. */
namespace paddle {
ThreadLocalD<std::vector<MemoryHandle *>> ConvBaseProjection::convMem_;
ThreadLocalD<std::vector<MemoryHandlePtr>> ConvBaseProjection::convMem_;
ConvBaseProjection::ConvBaseProjection(const ProjectionConfig &config,
ParameterPtr parameter,
......@@ -175,18 +175,18 @@ void ConvBaseProjection::reshape(int batchSize) {
}
void *ConvBaseProjection::getSpaceBytes(size_t size) {
std::vector<MemoryHandle *> &convMem = *convMem_;
std::vector<MemoryHandlePtr> &convMem = *convMem_;
if (convMem.empty()) {
int numDevices = hl_get_device_count();
convMem.resize(numDevices);
}
int devId = hl_get_device();
MemoryHandle **localMem = &(convMem[devId]);
if (NULL == *localMem || size > (*localMem)->getAllocSize()) {
*localMem = new GpuMemoryHandle(size);
MemoryHandlePtr localMem = convMem[devId];
if (NULL == localMem || size > localMem->getAllocSize()) {
localMem = std::make_shared<GpuMemoryHandle>(size);
}
return (*localMem)->getBuf();
return localMem->getBuf();
}
ConvBaseProjection::~ConvBaseProjection() {
......
......@@ -105,7 +105,7 @@ protected:
bool bias_;
std::unique_ptr<Weight> weight_;
static ThreadLocalD<std::vector<MemoryHandle*>> convMem_;
static ThreadLocalD<std::vector<MemoryHandlePtr>> convMem_;
};
} // namespace paddle
......@@ -110,7 +110,7 @@ function(op_library TARGET)
# It's enough to just adding one operator to pybind
file(APPEND ${pybind_file} "USE_GPU_ONLY_OP(ncclAllReduce);\n")
endif()
# reduce_op contains several operators
if ("${TARGET}" STREQUAL "reduce_op")
set(pybind_flag 1)
......@@ -118,6 +118,11 @@ function(op_library TARGET)
file(APPEND ${pybind_file} "USE_OP(reduce_sum);\n")
endif()
if ("${TARGET}" STREQUAL "tensor_array_read_write_op")
set(pybind_flag 1)
file(APPEND ${pybind_file} "USE_NO_KERNEL_OP(read_from_array);\nUSE_NO_KERNEL_OP(write_to_array);\n")
endif()
# pybind USE_NO_KERNEL_OP
# HACK: if REGISTER_OP_CPU_KERNEL presents the operator must have kernel
file(READ ${TARGET}.cc TARGET_CONTENT)
......@@ -161,6 +166,7 @@ set(DEPS_OPS
sequence_pool_op
lod_rank_table_op
lstm_op
tensor_array_read_write_op
gru_op)
op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op)
......@@ -171,6 +177,7 @@ op_library(sum_op DEPS net_op selected_rows_functor)
op_library(pool_op DEPS pooling)
op_library(pool_with_index_op DEPS pooling)
op_library(lod_rank_table_op SRCS lod_rank_table_op.cc DEPS lod_rank_table)
op_library(tensor_array_read_write_op SRCS tensor_array_read_write_op.cc)
if(WITH_GPU)
op_library(nccl_op DEPS nccl_common)
endif()
......
......@@ -72,11 +72,8 @@ class AccuracyOpCUDAKernel : public framework::OpKernel<T> {
}
AccuracyCudaKernel<PADDLE_CUDA_NUM_THREADS><<<
1, PADDLE_CUDA_NUM_THREADS, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(num_samples, infer_width, indices_data, label_data,
accuracy_data);
1, PADDLE_CUDA_NUM_THREADS, 0, ctx.cuda_device_context().stream()>>>(
num_samples, infer_width, indices_data, label_data, accuracy_data);
}
};
......
......@@ -27,7 +27,6 @@ using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using ScopedFilterDescriptor = platform::ScopedFilterDescriptor;
using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor;
using DataLayout = platform::DataLayout;
using CUDADeviceContext = platform::CUDADeviceContext;
static constexpr size_t kConvCudnnWorkspaceLimitBytes = 1024 * 1024 * 1024;
......
......@@ -27,7 +27,6 @@ using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using ScopedFilterDescriptor = platform::ScopedFilterDescriptor;
using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor;
using DataLayout = platform::DataLayout;
using CUDADeviceContext = platform::CUDADeviceContext;
static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES = 1024 * 1024 * 1024;
......
......@@ -130,9 +130,7 @@ class ConvShiftKernel<platform::GPUPlace, T> : public framework::OpKernel<T> {
dim3 grid_dim(num_x_blocks, batch_size);
auto stream = reinterpret_cast<const platform::CUDADeviceContext &>(
context.device_context())
.stream();
auto stream = context.cuda_device_context().stream();
conv_shift_forward<T><<<grid_dim, x_per_block, mem_per_block, stream>>>(
x_data, y_data, out_data, x_width, y_width, y_half_width, batch_size);
......@@ -159,9 +157,7 @@ class ConvShiftGradKernel<platform::GPUPlace, T>
int y_width = Y->dims()[1];
int y_half_width = (y_width - 1) / 2;
auto stream = reinterpret_cast<const platform::CUDADeviceContext &>(
context.device_context())
.stream();
auto stream = context.cuda_device_context().stream();
const int x_per_block = 256;
int num_x_blocks = div_up(x_width, x_per_block);
......
......@@ -82,24 +82,19 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel<T> {
int block = 512;
int grid = (batch_size * class_num + block - 1) / block;
auto stream = ctx.cuda_device_context().stream();
if (ctx.Attr<bool>("soft_label")) {
auto* label_data = label->data<T>();
SoftCrossEntropyGradientKernel<T><<<
grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(dx_data, dy_data, x_data, label_data,
batch_size, class_num);
SoftCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
dx_data, dy_data, x_data, label_data, batch_size, class_num);
} else {
math::SetConstant<platform::GPUPlace, T> functor;
functor(ctx.device_context(), dx, 0);
auto* label_data = label->data<int64_t>();
grid = (batch_size + block - 1) / block;
CrossEntropyGradientKernel<T><<<
grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(dx_data, dy_data, x_data, label_data,
batch_size, class_num);
CrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
dx_data, dy_data, x_data, label_data, batch_size, class_num);
}
}
};
......
......@@ -34,15 +34,18 @@ class FillConstantBatchSizeLikeOp : public framework::OperatorWithKernel {
std::vector<int64_t> shape_int64(shape.size(), 0);
std::transform(shape.begin(), shape.end(), shape_int64.begin(),
[](int a) { return static_cast<int64_t>(a); });
auto dims = framework::make_ddim(shape_int64);
auto output_dim = framework::make_ddim(shape_int64);
int dim_idx = ctx->Attrs().Get<int>("dim_idx");
PADDLE_ENFORCE_GE(dim_idx, 0);
PADDLE_ENFORCE_GT(static_cast<int>(shape.size()), dim_idx);
PADDLE_ENFORCE_GT(ctx->GetInputDim("Input").size(), dim_idx);
int input_dim_idx = ctx->Attrs().Get<int>("input_dim_idx");
PADDLE_ENFORCE_GE(input_dim_idx, 0);
PADDLE_ENFORCE_GT(ctx->GetInputDim("Input").size(), input_dim_idx);
dims[dim_idx] = ctx->GetInputDim("Input")[dim_idx];
ctx->SetOutputDim("Out", dims);
int output_dim_idx = ctx->Attrs().Get<int>("output_dim_idx");
PADDLE_ENFORCE_GE(output_dim_idx, 0);
PADDLE_ENFORCE_GT(static_cast<int>(shape.size()), output_dim_idx);
output_dim[output_dim_idx] = ctx->GetInputDim("Input")[input_dim_idx];
ctx->SetOutputDim("Out", output_dim);
}
protected:
......@@ -69,8 +72,11 @@ class FillConstantBatchSizeLikeOpMaker
"(Tensor) Tensor of specified shape will be filled "
"with the specified value");
AddAttr<std::vector<int>>("shape", "(vector<int>) The shape of the output");
AddAttr<int>("dim_idx",
"(int, default 0) The index of batch size dimension")
AddAttr<int>("input_dim_idx",
"(int, default 0) the index of input's batch size dimension")
.SetDefault(0);
AddAttr<int>("output_dim_idx",
"(int, default 0) the index of output's batch size dimension")
.SetDefault(0);
AddAttr<float>("value", "(float, default 0) The value to be filled")
.SetDefault(0.0f);
......@@ -86,9 +92,10 @@ Fill up a variable with specified constant value.
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(fill_constant_batch_size_like,
ops::FillConstantBatchSizeLikeOp,
ops::FillConstantBatchSizeLikeOpMaker);
REGISTER_OPERATOR(fill_constant_batch_size_like,
ops::FillConstantBatchSizeLikeOp,
paddle::framework::EmptyGradOpMaker,
ops::FillConstantBatchSizeLikeOpMaker);
REGISTER_OP_CPU_KERNEL(
fill_constant_batch_size_like,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CPUPlace, float>,
......
......@@ -35,7 +35,9 @@ class FillConstantOp : public framework::OperatorWithKernel {
protected:
framework::DataType IndicateDataType(
const framework::ExecutionContext &ctx) const override {
return static_cast<framework::DataType>(ctx.Attr<int>("data_type"));
int data_type = ctx.Attr<int>("data_type");
VLOG(10) << " FillConstant data_type = " << data_type;
return static_cast<framework::DataType>(data_type);
}
};
......@@ -71,4 +73,5 @@ REGISTER_OP_WITHOUT_GRADIENT(fill_constant, ops::FillConstantOp,
REGISTER_OP_CPU_KERNEL(
fill_constant, ops::FillConstantOpKernel<paddle::platform::CPUPlace, float>,
ops::FillConstantOpKernel<paddle::platform::CPUPlace, double>,
ops::FillConstantOpKernel<paddle::platform::CPUPlace, int>);
ops::FillConstantOpKernel<paddle::platform::CPUPlace, int>,
ops::FillConstantOpKernel<paddle::platform::CPUPlace, int64_t>);
......@@ -20,4 +20,5 @@ namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
fill_constant, ops::FillConstantOpKernel<paddle::platform::GPUPlace, float>,
ops::FillConstantOpKernel<paddle::platform::GPUPlace, double>,
ops::FillConstantOpKernel<paddle::platform::GPUPlace, int>);
ops::FillConstantOpKernel<paddle::platform::GPUPlace, int>,
ops::FillConstantOpKernel<paddle::platform::GPUPlace, int64_t>);
......@@ -31,7 +31,6 @@ class IncrementOp : public framework::OperatorWithKernel {
}
};
template <typename AttrType>
class IncrementOpMaker : public framework::OpProtoAndCheckerMaker {
public:
IncrementOpMaker(framework::OpProto *proto,
......@@ -39,10 +38,10 @@ class IncrementOpMaker : public framework::OpProtoAndCheckerMaker {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor) The input tensor of increment operator");
AddOutput("Out", "(Tensor) The output tensor of increment operator.");
AddAttr<AttrType>("step",
"(float, default 1.0) "
"The step size by which the "
"input tensor will be incremented.")
AddAttr<float>("step",
"(float, default 1.0) "
"The step size by which the "
"input tensor will be incremented.")
.SetDefault(1.0);
AddComment(R"DOC(
Increment Operator.
......@@ -73,7 +72,10 @@ class IncrementGradOpMaker : public framework::SingleGradOpDescMaker {
namespace ops = paddle::operators;
REGISTER_OPERATOR(increment, ops::IncrementOp, ops::IncrementOpMaker<float>,
REGISTER_OPERATOR(increment, ops::IncrementOp, ops::IncrementOpMaker,
ops::IncrementGradOpMaker);
REGISTER_OP_CPU_KERNEL(increment,
ops::IncrementKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
increment, ops::IncrementKernel<paddle::platform::CPUPlace, float>,
ops::IncrementKernel<paddle::platform::CPUPlace, double>,
ops::IncrementKernel<paddle::platform::CPUPlace, int>,
ops::IncrementKernel<paddle::platform::CPUPlace, int64_t>);
......@@ -16,4 +16,7 @@
REGISTER_OP_GPU_KERNEL(
increment,
paddle::operators::IncrementKernel<paddle::platform::GPUPlace, float>);
paddle::operators::IncrementKernel<paddle::platform::GPUPlace, float>,
paddle::operators::IncrementKernel<paddle::platform::GPUPlace, double>,
paddle::operators::IncrementKernel<paddle::platform::GPUPlace, int>,
paddle::operators::IncrementKernel<paddle::platform::GPUPlace, int64_t>);
......@@ -19,7 +19,7 @@
namespace paddle {
namespace operators {
template <typename Place, typename T, typename AttrType = T>
template <typename Place, typename T>
class IncrementKernel : public framework::OpKernel<T> {
public:
virtual void Compute(const framework::ExecutionContext& context) const {
......@@ -27,7 +27,7 @@ class IncrementKernel : public framework::OpKernel<T> {
auto* in = context.Input<framework::Tensor>("X");
tensor->mutable_data<T>(in->place());
auto step = static_cast<T>(context.Attr<AttrType>("step"));
auto step = static_cast<T>(context.Attr<float>("step"));
auto eigen_out = framework::EigenVector<T>::Flatten(*tensor);
auto eigen_in = framework::EigenVector<T>::Flatten(*in);
......
......@@ -74,10 +74,9 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
dim3 threads(128, 8);
dim3 grids(8, 1);
LookupTable<T, 128, 8, 8><<<
grids, threads, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
context.device_context())
.stream()>>>(output, table, ids, N, K, D);
LookupTable<T, 128, 8,
8><<<grids, threads, 0, context.device_context().stream()>>>(
output, table, ids, N, K, D);
}
};
......@@ -95,9 +94,7 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
auto* ids_data = ids->data<int64_t>();
auto ids_dim = ids->dims();
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
context.device_context())
.stream();
auto stream = context.cuda_device_context().stream();
// copy GPU memory to CPU pinned memory
framework::Vector<int64_t> new_rows;
new_rows.resize(ids_dim[0]);
......@@ -136,11 +133,10 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
dim3 threads(128, 8);
dim3 grids(8, 1);
LookupTableGrad<T, 128, 8,
8><<<grids, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(
context.device_context())
.stream()>>>(d_table, d_output, ids, N, K, D);
LookupTableGrad<
T, 128, 8,
8><<<grids, threads, 0, context.device_context().stream()>>>(
d_table, d_output, ids, N, K, D);
}
}
};
......
......@@ -35,9 +35,7 @@ class MultiplexGPUKernel : public framework::OpKernel<T> {
Tensor index_t_cpu;
index_t_cpu.CopyFrom(*ids, platform::CPUPlace(), ctx.device_context());
auto* index = index_t_cpu.data<int32_t>();
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream();
auto stream = ctx.cuda_device_context().stream();
Place place = boost::get<Place>(ctx.GetPlace());
for (auto i = 0; i < rows; i++) {
int32_t k = index[i];
......@@ -73,9 +71,7 @@ class MultiplexGradGPUKernel : public framework::OpKernel<T> {
index_t_cpu.CopyFrom(*ids, platform::CPUPlace(), ctx.device_context());
auto* index = index_t_cpu.data<int32_t>();
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream();
auto stream = ctx.device_context().stream();
Place place = boost::get<Place>(ctx.GetPlace());
for (auto i = 0; i < rows; i++) {
size_t k = static_cast<size_t>(index[i]);
......
......@@ -64,9 +64,7 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
auto* comm = ctx.Input<Communicator>("Communicator");
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream();
auto stream = ctx.cuda_device_context().stream();
// device id
int gpu_id = boost::get<platform::GPUPlace>(ctx.GetPlace()).GetDeviceId();
......
......@@ -24,10 +24,16 @@ class SumOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInputs("X"), "Inputs(X) should not be null");
auto x_dims = ctx->GetInputsDim("X");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SumOp should not be null.");
if (ctx->IsRuntime() &&
ctx->GetOutputsVarType("Out")[0] ==
framework::VarDesc::LOD_TENSOR_ARRAY) {
return; // skip runtime infershape when is tensor array;
}
auto x_dims = ctx->GetInputsDim("X");
size_t N = x_dims.size();
PADDLE_ENFORCE_GT(N, 1, "Input tensors count should > 1.");
......@@ -39,6 +45,28 @@ class SumOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("Out", in_dim);
ctx->ShareLoD("X", /*->*/ "Out");
}
protected:
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
auto x_vars = ctx.MultiInputVar("X");
if (x_vars[0]->IsType<framework::LoDTensor>()) {
return framework::ToDataType(
x_vars[0]->Get<framework::LoDTensor>().type());
} else if (x_vars[0]->IsType<framework::SelectedRows>()) {
return framework::ToDataType(
x_vars[0]->Get<framework::SelectedRows>().value().type());
} else if (x_vars[0]->IsType<framework::LoDTensorArray>()) {
auto& array = x_vars[0]->Get<framework::LoDTensorArray>();
for (auto& each : array) {
if (each.numel() != 0) {
return framework::ToDataType(each.type());
}
}
}
PADDLE_THROW("Unexpected branch. Input type is %s",
x_vars[0]->Type().name());
}
};
class SumOpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -63,18 +91,32 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
void operator()(const framework::OpDescBind& op_desc,
framework::BlockDescBind* block) const override {
auto& inputs = op_desc.Input("X");
auto default_var_type = framework::VarDesc::SELECTED_ROWS;
auto var_type = framework::VarDesc::SELECTED_ROWS;
bool any_input_is_lod_tensor = std::any_of(
inputs.begin(), inputs.end(), [block](const std::string& name) {
return block->Var(name)->GetType() == framework::VarDesc::LOD_TENSOR;
});
if (any_input_is_lod_tensor) {
default_var_type = framework::VarDesc::LOD_TENSOR;
auto is_tensor_array = [block](const std::string& name) {
return block->Var(name)->GetType() ==
framework::VarDesc::LOD_TENSOR_ARRAY;
};
bool any_input_is_tensor_array =
std::any_of(inputs.begin(), inputs.end(), is_tensor_array);
bool all_inputs_are_tensor_array =
std::all_of(inputs.begin(), inputs.end(), is_tensor_array);
if (any_input_is_tensor_array) {
PADDLE_ENFORCE(all_inputs_are_tensor_array);
var_type = framework::VarDesc::LOD_TENSOR_ARRAY;
} else if (any_input_is_lod_tensor) {
var_type = framework::VarDesc::LOD_TENSOR;
}
auto out_var_name = op_desc.Output("Out").front();
block->Var(out_var_name)->SetType(default_var_type);
block->Var(out_var_name)->SetType(var_type);
}
};
......
......@@ -11,6 +11,7 @@ limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/selected_rows_functor.h"
......@@ -28,7 +29,7 @@ using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T>
class SumKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
void Compute(const framework::ExecutionContext &context) const override {
auto in_vars = context.MultiInputVar("X");
int N = in_vars.size();
auto out_var = context.OutputVar("Out");
......@@ -36,7 +37,7 @@ class SumKernel : public framework::OpKernel<T> {
bool in_place = out_var == in_vars[0];
if (out_var->IsType<framework::LoDTensor>()) {
auto* out = context.Output<Tensor>("Out");
auto *out = context.Output<Tensor>("Out");
out->mutable_data<T>(context.GetPlace());
auto result = EigenVector<T>::Flatten(*out);
......@@ -51,11 +52,11 @@ class SumKernel : public framework::OpKernel<T> {
// If in_place, just skip the first tensor
for (int i = in_place ? 1 : 0; i < N; i++) {
if (in_vars[i]->IsType<framework::LoDTensor>()) {
auto& in_t = in_vars[i]->Get<framework::LoDTensor>();
auto &in_t = in_vars[i]->Get<framework::LoDTensor>();
auto in = EigenVector<T>::Flatten(in_t);
result.device(place) = result + in;
} else if (in_vars[i]->IsType<framework::SelectedRows>()) {
auto& in_t = in_vars[i]->Get<framework::SelectedRows>();
auto &in_t = in_vars[i]->Get<framework::SelectedRows>();
functor(context.device_context(), in_t, out);
} else {
PADDLE_THROW("Variable type must be LoDTensor/SelectedRows.");
......@@ -63,8 +64,8 @@ class SumKernel : public framework::OpKernel<T> {
}
} else if (out_var->IsType<framework::SelectedRows>()) {
PADDLE_ENFORCE(!in_place, "SelectedRows not support inplace sum now");
auto* out = context.Output<SelectedRows>("Out");
auto* out_value = out->mutable_value();
auto *out = context.Output<SelectedRows>("Out");
auto *out_value = out->mutable_value();
// Runtime InferShape
size_t first_dim = 0;
......@@ -88,9 +89,36 @@ class SumKernel : public framework::OpKernel<T> {
offset, out);
offset += in_vars[i]->Get<SelectedRows>().value().numel();
}
} else if (out_var->IsType<framework::LoDTensorArray>()) {
auto &out_array = *out_var->GetMutable<framework::LoDTensorArray>();
for (size_t i = in_place ? 1 : 0; i < in_vars.size(); ++i) {
PADDLE_ENFORCE(in_vars[i]->IsType<framework::LoDTensorArray>(),
"Only support all inputs are TensorArray");
auto &in_array = in_vars[i]->Get<framework::LoDTensorArray>();
for (size_t i = 0; i < in_array.size(); ++i) {
if (in_array[i].numel() != 0) {
if (i >= out_array.size()) {
out_array.resize(i + 1);
}
if (out_array[i].numel() == 0) {
out_array[i].CopyFrom(in_array[i], in_array[i].place(),
context.device_context());
out_array[i].set_lod(in_array[i].lod());
} else {
PADDLE_ENFORCE(out_array[i].lod() == in_array[i].lod());
auto in = EigenVector<T>::Flatten(in_array[i]);
auto result = EigenVector<T>::Flatten(out_array[i]);
result.device(context.GetEigenDevice<Place>()) = result + in;
}
}
}
}
} else {
PADDLE_THROW("Unexpected branch, output variable type is %s",
out_var->Type().name());
}
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
class ArrayOpBase : public framework::OperatorBase {
public:
ArrayOpBase(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {}
protected:
size_t GetOffset(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const {
auto *i = scope.FindVar(Input("I"));
PADDLE_ENFORCE(i != nullptr, "I must be set");
auto &i_tensor = i->Get<framework::LoDTensor>();
PADDLE_ENFORCE_EQ(i_tensor.numel(), 1);
size_t offset;
if (platform::is_gpu_place(i_tensor.place())) {
// FIXME: Avoid copy from GPU to CPU
framework::Tensor t;
t.CopyFrom(i_tensor, platform::CPUPlace(), dev_ctx);
dev_ctx.Wait();
offset = static_cast<size_t>(*t.data<int64_t>());
} else {
offset = static_cast<size_t>(*i_tensor.data<int64_t>());
}
return offset;
}
};
class WriteToArrayOp : public ArrayOpBase {
public:
WriteToArrayOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: ArrayOpBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
auto *x = scope.FindVar(Input("X"));
PADDLE_ENFORCE(x != nullptr, "X must be set");
auto &x_tensor = x->Get<framework::LoDTensor>();
size_t offset = GetOffset(scope, dev_ctx);
auto *out =
scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensorArray>();
if (offset >= out->size()) {
out->resize(offset + 1);
}
auto *out_tensor = &out->at(offset);
out_tensor->CopyFrom(x_tensor, dev_ctx.GetPlace(), dev_ctx);
out_tensor->set_lod(x_tensor.lod());
}
};
class WriteToArrayOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
WriteToArrayOpProtoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(LoDTensor) the tensor will be written to tensor array");
AddInput(
"I",
"(Tensor) the subscript index in tensor array. The number of element "
"should be 1");
AddOutput("Out", "(TensorArray) the tensor array will be written");
AddComment(R"DOC(Write a LoDTensor to a LoDTensor array.
Assume T is LoDTensor, i is the subscript of the array, and A is the array. The
equation is
A[i] = T
)DOC");
}
};
class WriteToArrayInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE(context->HasInput("I"), "Must set the subscript index");
PADDLE_ENFORCE_EQ(framework::product(context->GetInputDim("I")), 1,
"The number of element of subscript index must be 1");
PADDLE_ENFORCE(context->HasInput("X"), NotHasXError());
PADDLE_ENFORCE(context->HasOutput("Out"), NotHasOutError());
context->SetOutputDim("Out", context->GetInputDim("X"));
}
protected:
virtual const char *NotHasXError() const { return "Must set the lod tensor"; }
virtual const char *NotHasOutError() const {
return "Must set the lod tensor array";
}
};
class WriteToArrayInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDescBind &op_desc,
framework::BlockDescBind *block) const override {
for (auto &out_var : op_desc.OutputArgumentNames()) {
VLOG(10) << "Set Variable " << out_var << " as LOD_TENSOR_ARRAY";
block->Var(out_var)->SetType(framework::VarDesc::LOD_TENSOR_ARRAY);
}
}
};
class ReadFromArrayOp : public ArrayOpBase {
public:
ReadFromArrayOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: ArrayOpBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
auto *x = scope.FindVar(Input("X"));
PADDLE_ENFORCE(x != nullptr, "X must be set");
auto &x_array = x->Get<framework::LoDTensorArray>();
auto *out = scope.FindVar(Output("Out"));
PADDLE_ENFORCE(out != nullptr, "Out must be set");
auto *out_tesnor = out->GetMutable<framework::LoDTensor>();
size_t offset = GetOffset(scope, dev_ctx);
PADDLE_ENFORCE_LT(offset, x_array.size());
out_tesnor->CopyFrom(x_array[offset], dev_ctx.GetPlace(), dev_ctx);
out_tesnor->set_lod(x_array[offset].lod());
}
};
class ReadFromArrayProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
ReadFromArrayProtoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(TensorArray) the array will be read from.");
AddInput("I",
"(Tensor) the subscript index in tensor array. The number of "
"element should be 1");
AddOutput("Out", "(LoDTensor) the tensor will be read from.");
AddComment(R"DOC(Read a LoDTensor from a LoDTensor Array
Assume T is LoDTensor, i is th e subscript of the array, and A is the array. The
equation is
T = A[i]
)DOC");
}
};
class ReadFromArrayInferShape : public WriteToArrayInferShape {
protected:
const char *NotHasXError() const override {
return "The input array X must be set";
}
const char *NotHasOutError() const override {
return "The output tensor out must be set";
}
};
class WriteToArrayGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDescBind> Apply() const override {
auto *grad_op = new framework::OpDescBind();
grad_op->SetType("read_from_array");
grad_op->SetInput("I", Input("I"));
grad_op->SetInput("X", OutputGrad("Out"));
grad_op->SetOutput("Out", InputGrad("X"));
grad_op->SetAttrMap(Attrs());
return std::unique_ptr<framework::OpDescBind>(grad_op);
}
};
class ReadFromArrayGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDescBind> Apply() const override {
auto *grad_op = new framework::OpDescBind();
grad_op->SetType("write_to_array");
grad_op->SetInput("I", Input("I"));
grad_op->SetInput("X", OutputGrad("Out"));
grad_op->SetOutput("Out", InputGrad("X"));
grad_op->SetAttrMap(Attrs());
return std::unique_ptr<framework::OpDescBind>(grad_op);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(write_to_array, ops::WriteToArrayOp,
ops::WriteToArrayInferShape, ops::WriteToArrayOpProtoMaker,
ops::WriteToArrayGradMaker, ops::WriteToArrayInferVarType);
REGISTER_OPERATOR(read_from_array, ops::ReadFromArrayOp,
ops::ReadFromArrayInferShape, ops::ReadFromArrayProtoMaker,
ops::ReadFromArrayGradMaker);
include_directories(${CMAKE_CURRENT_BINARY_DIR})
set(OPITMIZER_SRCS
adadelta_optimizer.cc
adagrad_optimizer.cc
......@@ -9,11 +7,6 @@ set(OPITMIZER_SRCS
sgd_optimizer.cc
)
add_library(paddle_optimizer STATIC ${OPITMIZER_SRCS})
add_dependencies(paddle_optimizer paddle_proto ${external_project_dependencies})
if(WITH_TESTING)
add_simple_unittest(serialization_test)
add_simple_unittest(parameter_optimizer_test)
endif()
cc_library(paddle_optimizer STATIC SRCS ${OPITMIZER_SRCS} DEPS paddle_proto glog)
cc_test(serialization_test SRCS serialization_test.cc DEPS paddle_proto)
cc_test(parameter_optimizer_test SRCS parameter_optimizer_test.cc DEPS paddle_optimizer)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "adadelta_optimizer.h"
#include <algorithm>
#include <cmath>
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "parameter_optimizer.h"
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cmath>
#include "adagrad_optimizer.h"
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "parameter_optimizer.h"
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "adam_optimizer.h"
#include <cmath>
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "parameter_optimizer.h"
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "optimizer.h"
#include <glog/logging.h>
#include <cstdlib>
......@@ -6,8 +20,8 @@
#include "parameter_optimizer.h"
using namespace paddle;
using namespace paddle::optimizer;
using paddle::optimizer::ParameterOptimizer;
using paddle::optimizer::Tensor;
template <paddle_element_type VALUE>
struct EnumToType {};
......@@ -15,22 +29,21 @@ struct EnumToType {};
template <class T>
struct TypeToEnum {};
#define MATCH_ENUM_TYPE(TYPE, ENUM) \
template <> \
struct TypeToEnum<TYPE> { \
static paddle_element_type v() { return ENUM; }; \
static constexpr TYPE value = ENUM; \
}; \
template <> \
struct EnumToType<ENUM> { \
typedef TYPE Type; \
#define MATCH_ENUM_TYPE(TYPE, ENUM) \
template <> \
struct TypeToEnum<TYPE> { \
static paddle_element_type v() { return ENUM; } \
static constexpr TYPE value = ENUM; \
}; \
template <> \
struct EnumToType<ENUM> { \
typedef TYPE Type; \
}
MATCH_ENUM_TYPE(int32_t, PADDLE_ELEMENT_TYPE_INT32);
MATCH_ENUM_TYPE(uint32_t, PADDLE_ELEMENT_TYPE_UINT32);
MATCH_ENUM_TYPE(int64_t, PADDLE_ELEMENT_TYPE_INT64);
MATCH_ENUM_TYPE(uint64_t, PADDLE_ELEMENT_TYPE_UINT64);
// TODO(zhihong): only implement below type, need to fix
MATCH_ENUM_TYPE(float, PADDLE_ELEMENT_TYPE_FLOAT32);
MATCH_ENUM_TYPE(double, PADDLE_ELEMENT_TYPE_FLOAT64);
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <stdbool.h>
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <glog/logging.h>
#include "adadelta_optimizer.h"
#include "adagrad_optimizer.h"
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <glog/logging.h>
......
......@@ -110,7 +110,7 @@ public:
int s = 0;
float* newp = (float*)opts_[i]->get_weight(&s);
EXPECT_EQ(s, kSize);
EXPECT_EQ(static_cast<size_t>(s), kSize);
for (size_t j = 0; j < kSize; ++j) {
EXPECT_EQ(newp[j], (*p)[j]);
}
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "sgd_optimizer.h"
#include "serialization.h"
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "parameter_optimizer.h"
......@@ -15,7 +29,6 @@ public:
nesterov_(n) {
if (momentum_ != 0.0) {
size_t size = parameter->size();
// TODO: fix it with align aware allocator bind to Tensor
momentums_ = new Tensor(size);
}
}
......
......@@ -97,6 +97,15 @@ namespace pybind {
using namespace paddle::framework; // NOLINT
template <typename T>
static py::bytes SerializeMessage(T &self) {
// Check IsInitialized in Python
std::string retv;
PADDLE_ENFORCE(self.Proto()->SerializePartialToString(&retv),
"Cannot serialize message");
return retv;
}
// Bind Methods
void BindProgramDesc(py::module &m) {
py::class_<ProgramDescBind>(m, "ProgramDesc", "")
......@@ -132,17 +141,7 @@ void BindProgramDesc(py::module &m) {
.def("block", &ProgramDescBind::MutableBlock,
py::return_value_policy::reference)
.def("num_blocks", &ProgramDescBind::Size)
.def("serialize_to_string",
[](ProgramDescBind &program_desc) -> py::bytes {
const ProgramDesc *desc = program_desc.Proto();
PADDLE_ENFORCE(desc->IsInitialized(),
"ProgramDesc has not been initialized.");
std::string res;
PADDLE_ENFORCE(
desc->SerializeToString(&res),
"Serialize ProgramDesc Error. This could be a bug of Paddle.");
return res;
})
.def("serialize_to_string", SerializeMessage<ProgramDescBind>)
.def("parse_from_string",
[](ProgramDescBind &program_desc, const std::string &data) {
ProgramDesc *desc = program_desc.Proto();
......@@ -181,16 +180,7 @@ void BindBlockDesc(py::module &m) {
py::return_value_policy::reference)
.def("op_size", &BlockDescBind::OpSize)
.def("op", &BlockDescBind::Op, py::return_value_policy::reference)
.def("serialize_to_string", [](BlockDescBind &block_desc) -> py::bytes {
const BlockDesc *desc = block_desc.Proto();
PADDLE_ENFORCE(desc->IsInitialized(),
"BlockDesc has not been initialized.");
std::string res;
PADDLE_ENFORCE(
desc->SerializeToString(&res),
"Serialize BlockDesc Error. This could be a bug of Paddle.");
return res;
});
.def("serialize_to_string", SerializeMessage<BlockDescBind>);
}
void BindVarDsec(py::module &m) {
......@@ -219,17 +209,7 @@ void BindVarDsec(py::module &m) {
.def("set_lod_level", &VarDescBind::SetLoDLevel)
.def("type", &VarDescBind::GetType)
.def("set_type", &VarDescBind::SetType)
.def("serialize_to_string",
[](VarDescBind &var_desc) -> py::bytes {
const VarDesc *desc = var_desc.Proto();
PADDLE_ENFORCE(desc->IsInitialized(),
"VarDesc has not been initialized.");
std::string res;
PADDLE_ENFORCE(
desc->SerializeToString(&res),
"Serialize VarDesc Error. This could be a bug of Paddle.");
return res;
})
.def("serialize_to_string", SerializeMessage<VarDescBind>)
.def("persistable", &VarDescBind::Persistable)
.def("set_persistable", &VarDescBind::SetPersistable);
......@@ -274,16 +254,7 @@ void BindOpDesc(py::module &m) {
.def("check_attrs", &OpDescBind::CheckAttrs)
.def("infer_shape", &OpDescBind::InferShape)
.def("infer_var_type", &OpDescBind::InferVarType)
.def("serialize_to_string", [](OpDescBind &op_desc) -> py::bytes {
const OpDesc *desc = op_desc.Proto();
PADDLE_ENFORCE(desc->IsInitialized(),
"OpDesc has not been initialized.");
std::string res;
PADDLE_ENFORCE(
desc->SerializeToString(&res),
"Serialize OpDesc Error. This could be a bug of Paddle.");
return res;
});
.def("serialize_to_string", SerializeMessage<OpDescBind>);
}
} // namespace pybind
......
......@@ -168,6 +168,7 @@ EOF
${DOCKERFILE_GPU_ENV}
ADD go/cmd/pserver/pserver /usr/bin/
ADD go/cmd/master/master /usr/bin/
ADD paddle/pybind/print_operators_doc /usr/bin/
# default command shows the paddle version and exit
CMD ["paddle", "version"]
EOF
......
import paddle.v2.framework.core as core
from paddle.v2.framework.framework import Block, Program
from paddle.v2.framework.framework import Block, Program, g_main_program
g_scope = core.Scope()
......@@ -18,7 +18,7 @@ class Executor(object):
self.executor = core.Executor(act_places)
def run(self,
program,
program=None,
feed=None,
fetch_list=None,
feed_var_name='feed',
......@@ -29,6 +29,9 @@ class Executor(object):
if fetch_list is None:
fetch_list = []
if program is None:
program = g_main_program
if not isinstance(program, Program):
raise TypeError()
......
......@@ -12,6 +12,14 @@ def unique_name(prefix):
return "_".join([prefix, str(uid)])
def _debug_string_(proto):
error_fields = list()
if not proto.IsInitialized(error_fields):
raise ValueError("{0} are not initialized\nThe message is {1}".format(
error_fields, proto))
return proto.__str__()
class Variable(object):
def __init__(self,
block,
......@@ -95,7 +103,7 @@ class Variable(object):
def __str__(self):
protostr = self.desc.serialize_to_string()
proto = framework_pb2.VarDesc.FromString(str(protostr))
return proto.__str__()
return _debug_string_(proto)
__repr__ = __str__
......@@ -286,7 +294,7 @@ class Operator(object):
def __str__(self):
protostr = self.desc.serialize_to_string()
proto = framework_pb2.OpDesc.FromString(str(protostr))
return proto.__str__()
return _debug_string_(proto)
__repr__ = __str__
......@@ -343,7 +351,7 @@ class Block(object):
def __str__(self):
protostr = self.desc.serialize_to_string()
proto = framework_pb2.BlockDesc.FromString(str(protostr))
return proto.__str__()
return _debug_string_(proto)
__repr__ = __str__
......@@ -448,7 +456,7 @@ class Program(object):
def __str__(self):
protostr = self.desc.serialize_to_string()
proto = framework_pb2.ProgramDesc.FromString(str(protostr))
return proto.__str__()
return _debug_string_(proto)
def clone(self):
p = Program()
......
import paddle.v2.framework.core as core
from paddle.v2.framework.framework import OpProtoHolder, Variable, Program, Operator
from paddle.v2.framework.initializer import ConstantInitializer, NormalInitializer
from paddle.v2.framework.framework import OpProtoHolder, Variable, Program, \
Operator
from paddle.v2.framework.initializer import ConstantInitializer, \
NormalInitializer
from paddle.v2.framework.layer_helper import LayerHelper, unique_name
import re
......@@ -579,25 +581,45 @@ class StaticRNN(object):
if self.status != StaticRNN.IN_RNN_BLOCK:
raise ValueError("You must invoke {0} in rnn block".format(method))
def memory(self, init=None, shape=None, dtype=None, init_value=0):
def memory(self,
init=None,
shape=None,
batch_ref=None,
init_value=0.0,
init_batch_dim_idx=0,
ref_batch_dim_idx=1):
'''
:param init: boot memory, if not set, a shape, batch_ref must be provided
:param shape: shape of the boot memory
:param batch_ref: batch size reference variable
:param init_value: the init value of boot memory
:param init_batch_dim_idx: the index of batch size in init's dimension
:param ref_batch_dim_idx: the index of batch size in batch_ref's dimension
:return: boot memory
'''
self._assert_in_rnn_block_('memory')
if init is None:
if shape is None or dtype is None:
if shape is None or batch_ref is None:
raise ValueError(
"if init is None, memory at least need shape and dtype")
"if init is None, memory at least need shape and batch_ref")
parent_block = self.parent_block()
var_name = unique_name("@".join([self.helper.name, "memory_boot"]))
boot_var = parent_block.create_var(
name=var_name, shape=shape, dtype=dtype, persistable=False)
name=var_name,
shape=shape,
dtype=batch_ref.data_type,
persistable=False)
parent_block.append_op(
type="fill_constant",
inputs={},
type="fill_constant_batch_size_like",
inputs={'Input': [batch_ref]},
outputs={'Out': [boot_var]},
attrs={
'value': init_value,
'shape': [40] + list(boot_var.shape[1:]),
'data_type': boot_var.data_type
'shape': boot_var.shape,
'data_type': boot_var.data_type,
'input_dim_idx': ref_batch_dim_idx,
'output_dim_idx': init_batch_dim_idx
})
return self.memory(init=boot_var)
......@@ -751,3 +773,68 @@ def lod_rank_table(x, level=0, main_program=None):
outputs={'Out': table},
attrs={'level': level})
return table
def fill_constant(shape, dtype, value, main_program=None):
helper = LayerHelper("ones", **locals())
out = helper.create_tmp_variable(dtype=dtype)
helper.append_op(
type='fill_constant',
inputs={},
outputs={'Out': [out]},
attrs={
'shape': shape,
'data_type': out.data_type,
'value': float(value)
})
out.stop_gradient = True
return out
def ones(shape, dtype, main_program=None):
return fill_constant(value=1.0, **locals())
def zeros(shape, dtype, main_program=None):
return fill_constant(value=0.0, **locals())
def increment(x, value=1.0, main_program=None):
helper = LayerHelper("increment", **locals())
tmp = helper.create_tmp_variable(dtype=x.data_type)
helper.append_op(
type='increment',
inputs={'X': [x]},
outputs={'Out': [tmp]},
attrs={'step': value})
return tmp
def array_write(x, i, array=None, main_program=None):
helper = LayerHelper('array_write', **locals())
if array is None:
array = helper.create_variable(
name="{0}.out".format(helper.name),
type=core.VarDesc.VarType.LOD_TENSOR_ARRAY,
dtype=x.data_type)
helper.append_op(
type='write_to_array',
inputs={'X': [x],
'I': [i]},
outputs={'Out': [array]})
return array
def array_read(array, i, main_program=None):
helper = LayerHelper('array_read', **locals())
if not isinstance(
array,
Variable) or array.type != core.VarDesc.VarType.LOD_TENSOR_ARRAY:
raise TypeError("array should be tensor array vairable")
out = helper.create_tmp_variable(dtype=array.data_type)
helper.append_op(
type='read_from_array',
inputs={'X': [array],
'I': [i]},
outputs={'Out': [out]})
return out
import unittest
import paddle.v2.framework.core as core
import paddle.v2.framework.layers as layers
from paddle.v2.framework.executor import Executor
from paddle.v2.framework.backward import append_backward_ops
from paddle.v2.framework.framework import g_main_program
import numpy
class TestArrayReadWrite(unittest.TestCase):
def test_read_write(self):
x = [
layers.data(
name='x0', shape=[100]), layers.data(
name='x1', shape=[100]), layers.data(
name='x2', shape=[100])
]
for each_x in x:
each_x.stop_gradient = False
i = layers.zeros(shape=[1], dtype='int64')
arr = layers.array_write(x=x[0], i=i)
i = layers.increment(x=i)
i.stop_gradient = True
arr = layers.array_write(x=x[1], i=i, array=arr)
i = layers.increment(x=i)
i.stop_gradient = True
arr = layers.array_write(x=x[2], i=i, array=arr)
i = layers.zeros(shape=[1], dtype='int64')
a0 = layers.array_read(array=arr, i=i)
i = layers.increment(x=i)
i.stop_gradient = True # index should not calculate gradient
a1 = layers.array_read(array=arr, i=i)
i = layers.increment(x=i)
i.stop_gradient = True
a2 = layers.array_read(array=arr, i=i)
mean_a0 = layers.mean(x=a0)
mean_a1 = layers.mean(x=a1)
mean_a2 = layers.mean(x=a2)
a_sum = layers.sums(input=[mean_a0, mean_a1, mean_a2])
mean_x0 = layers.mean(x=x[0])
mean_x1 = layers.mean(x=x[1])
mean_x2 = layers.mean(x=x[2])
x_sum = layers.sums(input=[mean_x0, mean_x1, mean_x2])
scope = core.Scope()
cpu = core.CPUPlace()
exe = Executor(cpu)
tensor = core.LoDTensor()
tensor.set(numpy.random.random(size=(100, 100)).astype('float32'), cpu)
outs = map(numpy.array,
exe.run(feed={'x0': tensor,
'x1': tensor,
'x2': tensor},
fetch_list=[a_sum, x_sum],
scope=scope))
self.assertEqual(outs[0], outs[1])
total_sum = layers.sums(input=[a_sum, x_sum])
total_sum_scaled = layers.scale(x=total_sum, scale=1 / 6.0)
append_backward_ops(total_sum_scaled)
g_vars = map(g_main_program.global_block().var,
[each_x.name + "@GRAD" for each_x in x])
g_out = [
item.sum()
for item in map(
numpy.array,
exe.run(feed={'x0': tensor,
'x1': tensor,
'x2': tensor},
fetch_list=g_vars))
]
g_out_sum = numpy.array(g_out).sum()
# since our final gradient is 1 and the neural network are all linear
# with mean_op.
# the input gradient should also be 1
self.assertAlmostEqual(1.0, g_out_sum, delta=0.1)
if __name__ == '__main__':
unittest.main()
......@@ -21,9 +21,14 @@ class TestFillConstantBatchSizeLikeWhenSecondDimIsBatchSize(OpTest):
def setUp(self):
self.op_type = "fill_constant_batch_size_like"
self.inputs = {'Input': np.random.random((219, 232)).astype("float32")}
self.attrs = {'value': 3.5, 'shape': [132, -1, 7], 'dim_idx': 1}
out = np.random.random((132, 232, 7)).astype("float32")
self.attrs = {
'value': 3.5,
'shape': [132, -1, 7],
'input_dim_idx': 0,
'output_dim_idx': 1
}
out = np.random.random((132, 219, 7)).astype("float32")
out.fill(3.5)
self.outputs = {'Out': out}
......
import unittest
from paddle.v2.framework.framework import Program
class TestDebugStringFramework(unittest.TestCase):
def test_debug_str(self):
p = Program()
p.current_block().create_var(name='t', shape=[0, 1])
self.assertRaises(ValueError, callableObj=p.__str__)
if __name__ == '__main__':
unittest.main()
import unittest
import logging
from op_test import get_numeric_gradient
from paddle.v2.framework.layers import *
import paddle.v2.framework.layers as layers
from paddle.v2.framework.framework import Program
from paddle.v2.framework.executor import Executor
from paddle.v2.framework.backward import append_backward_ops
......@@ -16,8 +13,8 @@ class PyRNNBase(object):
self.x = np.ones(shape=input_shape).astype("float32")
self.y = np.zeros(shape=output_shape).astype("float32")
def step(self):
pass
def step(self, step_id, x):
raise NotImplementedError
def forward(self):
for step_id in range(self.x.shape[0]):
......@@ -116,30 +113,30 @@ class RecurrentOpTest1(unittest.TestCase):
self.output_shape = (self.sent_len, self.batch_size, self.input_dim)
self.py_rnn = PySimpleRNN1(self.input_shape, self.output_shape)
self.output = mean(x=self.create_rnn_op(), **self.p_info)
self.output = layers.mean(x=self.create_rnn_op(), **self.p_info)
def create_rnn_op(self):
x = data(
x = layers.data(
shape=[self.sent_len, self.batch_size, self.input_dim],
data_type='float32',
name='x',
append_batch_size=False,
**self.p_info)
x.stop_gradient = False
h_boot = data(
h_boot = layers.data(
shape=[self.input_dim],
data_type='float32',
name='h_boot',
**self.p_info)
h_boot.stop_gradient = False
rnn = StaticRNN(main_program=self.main_program)
rnn = layers.StaticRNN(main_program=self.main_program)
with rnn.step():
h_pre = rnn.memory(init=h_boot)
x_t = rnn.step_input(x)
h = scale(
x=elementwise_add(
h = layers.scale(
x=layers.elementwise_add(
x=h_pre, y=x_t, **self.p_info),
scale=self.py_rnn.scale,
**self.p_info)
......@@ -249,41 +246,41 @@ class RecurrentOpTest2(RecurrentOpTest1):
self.output_shape = (self.sent_len, self.batch_size, self.input_dim)
self.py_rnn = PySimpleRNN2(self.input_shape, self.output_shape)
self.output = mean(x=self.create_rnn_op(), **self.p_info)
self.output = layers.mean(x=self.create_rnn_op(), **self.p_info)
def create_rnn_op(self):
x = data(
x = layers.data(
shape=[self.sent_len, self.batch_size, self.input_dim],
data_type='float32',
name='x',
append_batch_size=False,
**self.p_info)
x.stop_gradient = False
h_boot = data(
h_boot = layers.data(
shape=[self.input_dim],
data_type='float32',
name='h_boot',
**self.p_info)
h_boot.stop_gradient = False
rnn = StaticRNN(main_program=self.main_program)
rnn = layers.StaticRNN(main_program=self.main_program)
with rnn.step():
h_pre = rnn.memory(init=h_boot)
x_t = rnn.step_input(x)
temp_l = fc(input=x_t,
size=self.input_dim,
param_attr={'name': 'W'},
bias_attr=False,
**self.p_info)
temp_r = fc(input=h_pre,
size=self.input_dim,
param_attr={'name': 'U'},
bias_attr=False,
**self.p_info)
h = sigmoid(
x=elementwise_add(
temp_l = layers.fc(input=x_t,
size=self.input_dim,
param_attr={'name': 'W'},
bias_attr=False,
**self.p_info)
temp_r = layers.fc(input=h_pre,
size=self.input_dim,
param_attr={'name': 'U'},
bias_attr=False,
**self.p_info)
h = layers.sigmoid(
x=layers.elementwise_add(
x=temp_l, y=temp_r, **self.p_info),
**self.p_info)
......@@ -293,7 +290,7 @@ class RecurrentOpTest2(RecurrentOpTest1):
return rnn()
class RecurrentOpTest3(RecurrentOpTest1):
class RecurrentOpMultipleMemoryTest(RecurrentOpTest1):
'''
Test RNNOp with two memories
equation:
......@@ -310,8 +307,8 @@ class RecurrentOpTest3(RecurrentOpTest1):
class PySimpleRNN3(PyRNNBase):
def __init__(self, input_shape, output_shape):
super(RecurrentOpTest3.PySimpleRNN3, self).__init__(input_shape,
output_shape)
super(RecurrentOpMultipleMemoryTest.PySimpleRNN3, self).__init__(
input_shape, output_shape)
seq_len, batch_size, input_dim = input_shape
self.h_boot1 = np.random.normal(size=(batch_size,
......@@ -345,27 +342,27 @@ class RecurrentOpTest3(RecurrentOpTest1):
self.input_shape = (self.sent_len, self.batch_size, self.input_dim)
self.output_shape = (self.sent_len, self.batch_size, self.input_dim)
self.py_rnn = RecurrentOpTest3.PySimpleRNN3(self.input_shape,
self.output_shape)
self.py_rnn = RecurrentOpMultipleMemoryTest.PySimpleRNN3(
self.input_shape, self.output_shape)
self.output = mean(x=self.create_rnn_op(), **self.p_info)
self.output = layers.mean(x=self.create_rnn_op(), **self.p_info)
def create_rnn_op(self):
x = data(
x = layers.data(
shape=[self.sent_len, self.batch_size, self.input_dim],
data_type='float32',
name='x',
append_batch_size=False,
**self.p_info)
x.stop_gradient = False
h_boot1 = data(
h_boot1 = layers.data(
shape=[self.batch_size, self.input_dim],
data_type='float32',
name='h_boot1',
append_batch_size=False,
**self.p_info)
h_boot1.stop_gradient = False
h_boot2 = data(
h_boot2 = layers.data(
shape=[self.batch_size, self.input_dim],
data_type='float32',
name='h_boot2',
......@@ -373,15 +370,15 @@ class RecurrentOpTest3(RecurrentOpTest1):
**self.p_info)
h_boot2.stop_gradient = False
rnn = StaticRNN(main_program=self.main_program)
rnn = layers.StaticRNN(main_program=self.main_program)
with rnn.step():
h_pre1 = rnn.memory(init=h_boot1)
h_pre2 = rnn.memory(init=h_boot2)
x_t = rnn.step_input(x)
mem1 = scale(x=h_pre1, scale=1.0, **self.p_info)
mem2 = scale(x=h_pre2, scale=1.0, **self.p_info)
out = sums(input=[mem1, x_t, mem2], **self.p_info)
mem1 = layers.scale(x=h_pre1, scale=1.0, **self.p_info)
mem2 = layers.scale(x=h_pre2, scale=1.0, **self.p_info)
out = layers.sums(input=[mem1, x_t, mem2], **self.p_info)
rnn.update_memory(h_pre1, mem1)
rnn.update_memory(h_pre2, mem2)
......@@ -390,5 +387,70 @@ class RecurrentOpTest3(RecurrentOpTest1):
return rnn()
class RecurrentOpNoMemBootTest(RecurrentOpTest1):
'''
Test RNNOp with two memories
equation:
mem = x + mem_pre
y = mem
vars:
- x
memories:
- mem
outputs:
- y
'''
class PySimpleRNN4(PyRNNBase):
def __init__(self, input_shape, output_shape):
super(RecurrentOpNoMemBootTest.PySimpleRNN4, self).__init__(
input_shape, output_shape)
men_dim = input_shape
self.mems = np.zeros(shape=men_dim).astype("float32")
def step(self, step_id, x):
if step_id == 0:
pre_mem = np.zeros_like(x)
else:
pre_mem = self.mems[step_id - 1]
self.mems[step_id] = pre_mem + x
self.y[step_id] = self.mems[step_id]
input_dim = 1
batch_size = 1
sent_len = 2
def setUp(self):
self.setup_program()
self.data_field = {"x"}
self.input_shape = (self.sent_len, self.batch_size, self.input_dim)
self.output_shape = (self.sent_len, self.batch_size, self.input_dim)
self.py_rnn = RecurrentOpNoMemBootTest.PySimpleRNN4(self.input_shape,
self.output_shape)
self.output = layers.mean(x=self.create_rnn_op(), **self.p_info)
print self.main_program
def create_rnn_op(self):
x = layers.data(
shape=[self.sent_len, self.batch_size, self.input_dim],
data_type='float32',
name='x',
append_batch_size=False,
**self.p_info)
x.stop_gradient = False
rnn = layers.StaticRNN(main_program=self.main_program)
with rnn.step():
mem_pre = rnn.memory(shape=[-1, self.input_dim], batch_ref=x)
x_t = rnn.step_input(x)
mem = layers.elementwise_add(x=mem_pre, y=x_t, **self.p_info)
rnn.update_memory(mem_pre, mem)
rnn.output(mem)
return rnn()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册