提交 ec01f635 编写于 作者: Y Yang Yang

merge develop

......@@ -341,7 +341,7 @@ static void CreateGradVarInBlock(
auto* param = block_desc->FindVarRecursive(pname);
auto* grad = block_desc->FindVar(arg);
if (param == nullptr) {
grad->SetDataType(proto::DataType::FP32);
grad->SetDataType(proto::VarType::FP32);
} else {
grad->SetDataType(param->GetDataType());
}
......
......@@ -51,10 +51,10 @@ class TestOpWithKernel : public OperatorWithKernel {
const ExecutionContext& ctx) const override {
if (Attr<bool>("use_gpu")) {
VLOG(3) << "force use gpu kernel";
return OpKernelType(proto::DataType::FP32, platform::CUDAPlace(0));
return OpKernelType(proto::VarType::FP32, platform::CUDAPlace(0));
} else {
VLOG(3) << "use default kernel";
return OpKernelType(proto::DataType::FP32,
return OpKernelType(proto::VarType::FP32,
ctx.Input<Tensor>("input")->place());
}
}
......
......@@ -27,9 +27,9 @@ TEST(DataTransform, DataLayoutFunction) {
in.mutable_data<double>(make_ddim({2, 3, 1, 2}), place);
in.set_layout(DataLayout::kNHWC);
auto kernel_nhwc = OpKernelType(proto::DataType::FP32, place,
auto kernel_nhwc = OpKernelType(proto::VarType::FP32, place,
DataLayout::kNHWC, LibraryType::kPlain);
auto kernel_ncwh = OpKernelType(proto::DataType::FP32, place,
auto kernel_ncwh = OpKernelType(proto::VarType::FP32, place,
DataLayout::kNCHW, LibraryType::kPlain);
TransDataLayout(kernel_nhwc, kernel_ncwh, in, &out);
......
......@@ -20,35 +20,35 @@ limitations under the License. */
namespace paddle {
namespace framework {
inline proto::DataType ToDataType(std::type_index type) {
inline proto::VarType::Type ToDataType(std::type_index type) {
using namespace paddle::framework::proto;
if (typeid(float).hash_code() == type.hash_code()) {
return DataType::FP32;
return proto::VarType::FP32;
} else if (typeid(double).hash_code() == type.hash_code()) {
return DataType::FP64;
return proto::VarType::FP64;
} else if (typeid(int).hash_code() == type.hash_code()) {
return DataType::INT32;
return proto::VarType::INT32;
} else if (typeid(int64_t).hash_code() == type.hash_code()) {
return DataType::INT64;
return proto::VarType::INT64;
} else if (typeid(bool).hash_code() == type.hash_code()) {
return DataType::BOOL;
return proto::VarType::BOOL;
} else {
PADDLE_THROW("Not supported");
}
}
inline std::type_index ToTypeIndex(proto::DataType type) {
inline std::type_index ToTypeIndex(proto::VarType::Type type) {
using namespace paddle::framework::proto;
switch (type) {
case DataType::FP32:
case proto::VarType::FP32:
return typeid(float);
case DataType::FP64:
case proto::VarType::FP64:
return typeid(double);
case DataType::INT32:
case proto::VarType::INT32:
return typeid(int);
case DataType::INT64:
case proto::VarType::INT64:
return typeid(int64_t);
case DataType::BOOL:
case proto::VarType::BOOL:
return typeid(bool);
default:
PADDLE_THROW("Not support type %d", type);
......@@ -56,22 +56,22 @@ inline std::type_index ToTypeIndex(proto::DataType type) {
}
template <typename Visitor>
inline void VisitDataType(proto::DataType type, Visitor visitor) {
inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
using namespace paddle::framework::proto;
switch (type) {
case DataType::FP32:
case proto::VarType::FP32:
visitor.template operator()<float>();
break;
case DataType::FP64:
case proto::VarType::FP64:
visitor.template operator()<double>();
break;
case DataType::INT32:
case proto::VarType::INT32:
visitor.template operator()<int>();
break;
case DataType::INT64:
case proto::VarType::INT64:
visitor.template operator()<int64_t>();
break;
case DataType::BOOL:
case proto::VarType::BOOL:
visitor.template operator()<bool>();
break;
default:
......@@ -79,22 +79,22 @@ inline void VisitDataType(proto::DataType type, Visitor visitor) {
}
}
inline std::string DataTypeToString(const proto::DataType type) {
inline std::string DataTypeToString(const proto::VarType::Type type) {
using namespace paddle::framework::proto;
switch (type) {
case DataType::FP16:
case proto::VarType::FP16:
return "float16";
case DataType::FP32:
case proto::VarType::FP32:
return "float32";
case DataType::FP64:
case proto::VarType::FP64:
return "float64";
case DataType::INT16:
case proto::VarType::INT16:
return "int16";
case DataType::INT32:
case proto::VarType::INT32:
return "int32";
case DataType::INT64:
case proto::VarType::INT64:
return "int64";
case DataType::BOOL:
case proto::VarType::BOOL:
return "bool";
default:
PADDLE_THROW("Not support type %d", type);
......@@ -102,7 +102,7 @@ inline std::string DataTypeToString(const proto::DataType type) {
}
inline std::ostream& operator<<(std::ostream& out,
const proto::DataType& type) {
const proto::VarType::Type& type) {
out << DataTypeToString(type);
return out;
}
......
......@@ -65,19 +65,19 @@ void TransDataType(const OpKernelType& kernel_type_for_var,
auto ctx = pool.Get(in.place());
switch (src_type) {
case proto::DataType::FP32:
case proto::VarType::FP32:
framework::VisitDataType(dst_type, CastDataType<float>(in, out, ctx));
break;
case proto::DataType::FP64:
case proto::VarType::FP64:
framework::VisitDataType(dst_type, CastDataType<double>(in, out, ctx));
break;
case proto::DataType::INT32:
case proto::VarType::INT32:
framework::VisitDataType(dst_type, CastDataType<int>(in, out, ctx));
break;
case proto::DataType::INT64:
case proto::VarType::INT64:
framework::VisitDataType(dst_type, CastDataType<int64_t>(in, out, ctx));
break;
case proto::DataType::BOOL:
case proto::VarType::BOOL:
framework::VisitDataType(dst_type, CastDataType<bool>(in, out, ctx));
break;
default:
......
......@@ -32,11 +32,11 @@ TEST(DataTypeTransform, CPUTransform) {
ptr[i] = i / 3;
}
auto kernel_fp32 = OpKernelType(proto::DataType::FP32, place,
auto kernel_fp32 = OpKernelType(proto::VarType::FP32, place,
DataLayout::kAnyLayout, LibraryType::kPlain);
auto kernel_fp64 = OpKernelType(proto::DataType::FP64, place,
auto kernel_fp64 = OpKernelType(proto::VarType::FP64, place,
DataLayout::kAnyLayout, LibraryType::kPlain);
auto kernel_int32 = OpKernelType(proto::DataType::INT32, place,
auto kernel_int32 = OpKernelType(proto::VarType::INT32, place,
DataLayout::kAnyLayout, LibraryType::kPlain);
TransDataType(kernel_fp32, kernel_fp64, in, &out);
......
......@@ -91,34 +91,35 @@ message OpProto {
required string comment = 5;
}
enum DataType {
BOOL = 0;
INT16 = 1;
INT32 = 2;
INT64 = 3;
FP16 = 4;
FP32 = 5;
FP64 = 6;
}
message VarType {
enum Type {
LOD_TENSOR = 1;
SELECTED_ROWS = 2;
FEED_MINIBATCH = 3;
FETCH_LIST = 4;
STEP_SCOPES = 5;
LOD_RANK_TABLE = 6;
LOD_TENSOR_ARRAY = 7;
PLACE_LIST = 8;
READER = 9;
NCCL_COM = 10;
// Pod Types
BOOL = 0;
INT16 = 1;
INT32 = 2;
INT64 = 3;
FP16 = 4;
FP32 = 5;
FP64 = 6;
// Other types that may need additional descriptions
LOD_TENSOR = 7;
SELECTED_ROWS = 8;
FEED_MINIBATCH = 9;
FETCH_LIST = 10;
STEP_SCOPES = 11;
LOD_RANK_TABLE = 12;
LOD_TENSOR_ARRAY = 13;
PLACE_LIST = 14;
READER = 15;
NCCL_COM = 16;
}
required Type type = 1;
message TensorDesc {
required DataType data_type = 1;
// Should only be PODType. Is enforced in C++
required Type data_type = 1;
repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480]
}
optional TensorDesc selected_rows = 2;
......
......@@ -40,12 +40,12 @@ struct OpKernelType {
// place, data_type, library_type kinds less than 2^8
constexpr static int LEFT_SHIFT = 8;
proto::DataType data_type_;
proto::VarType::Type data_type_;
DataLayout data_layout_;
platform::Place place_;
LibraryType library_type_;
OpKernelType(proto::DataType data_type, platform::Place place,
OpKernelType(proto::VarType::Type data_type, platform::Place place,
DataLayout data_layout = DataLayout::kAnyLayout,
LibraryType library_type = LibraryType::kPlain)
: data_type_(data_type),
......@@ -53,7 +53,7 @@ struct OpKernelType {
place_(place),
library_type_(library_type) {}
OpKernelType(proto::DataType data_type,
OpKernelType(proto::VarType::Type data_type,
const platform::DeviceContext& dev_ctx,
DataLayout data_layout = DataLayout::kAnyLayout,
LibraryType library_type = LibraryType::kPlain)
......
......@@ -18,7 +18,7 @@ limitations under the License. */
TEST(OpKernelType, ToString) {
using OpKernelType = paddle::framework::OpKernelType;
using DataType = paddle::framework::proto::DataType;
using DataType = paddle::framework::proto::VarType;
using CPUPlace = paddle::platform::CPUPlace;
using DataLayout = paddle::framework::DataLayout;
using LibraryType = paddle::framework::LibraryType;
......@@ -33,7 +33,7 @@ TEST(OpKernelType, ToString) {
TEST(OpKernelType, Hash) {
using OpKernelType = paddle::framework::OpKernelType;
using DataType = paddle::framework::proto::DataType;
using DataType = paddle::framework::proto::VarType;
using CPUPlace = paddle::platform::CPUPlace;
using CUDAPlace = paddle::platform::CUDAPlace;
using DataLayout = paddle::framework::DataLayout;
......
......@@ -226,7 +226,7 @@ class OpWithKernelTest : public OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(proto::DataType::FP32, ctx.device_context());
return framework::OpKernelType(proto::VarType::FP32, ctx.device_context());
}
};
......@@ -290,9 +290,9 @@ class OpWithMultiKernelTest : public OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
proto::DataType::FP32, platform::CUDAPlace(0), DataLayout::kAnyLayout,
framework::LibraryType::kCUDNN);
return framework::OpKernelType(proto::VarType::FP32, platform::CUDAPlace(0),
DataLayout::kAnyLayout,
framework::LibraryType::kCUDNN);
}
};
......
......@@ -569,7 +569,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
}
}
proto::DataType OperatorWithKernel::IndicateDataType(
proto::VarType::Type OperatorWithKernel::IndicateDataType(
const ExecutionContext& ctx) const {
auto& scope = ctx.scope();
int data_type = -1;
......@@ -595,7 +595,7 @@ proto::DataType OperatorWithKernel::IndicateDataType(
}
}
PADDLE_ENFORCE(data_type != -1, "DataType should be indicated by input");
return static_cast<proto::DataType>(data_type);
return static_cast<proto::VarType::Type>(data_type);
}
OpKernelType OperatorWithKernel::GetExpectedKernelType(
......
......@@ -394,9 +394,9 @@ class OperatorWithKernel : public OperatorBase {
const OpKernelType& expected_kernel_type) const;
private:
// indicate kernel DataType by input data. Defaultly all input data must be
// indicate kernel DataType by input data. By default all input data must be
// same.
proto::DataType IndicateDataType(const ExecutionContext& ctx) const;
proto::VarType::Type IndicateDataType(const ExecutionContext& ctx) const;
void RunImpl(const Scope& scope, const platform::Place& place) const final;
};
......
......@@ -119,7 +119,7 @@ class OpWithKernelTest : public OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {}
OpKernelType GetExpectedKernelType(
const ExecutionContext& ctx) const override {
return OpKernelType(proto::DataType::FP32, ctx.GetPlace());
return OpKernelType(proto::VarType::FP32, ctx.GetPlace());
}
};
......
......@@ -24,13 +24,13 @@ TEST(ProgramDesc, copy_ctor) {
auto* x = global_block->Var("X");
x->SetType(proto::VarType::LOD_TENSOR);
x->SetLoDLevel(0);
x->SetDataType(proto::FP32);
x->SetDataType(proto::VarType::FP32);
x->SetShape({1000, 784});
auto* y = global_block->Var("Y");
y->SetType(proto::VarType::LOD_TENSOR);
y->SetLoDLevel(0);
y->SetDataType(proto::FP32);
y->SetDataType(proto::VarType::FP32);
y->SetShape({784, 100});
auto* op = global_block->AppendOp();
......@@ -86,13 +86,13 @@ TEST(ProgramDescBind, serialize_and_deserialize) {
auto* x = global_block->Var("X");
x->SetType(proto::VarType::LOD_TENSOR);
x->SetLoDLevel(0);
x->SetDataType(proto::FP32);
x->SetDataType(proto::VarType::FP32);
x->SetShape({1000, 784});
auto* y = global_block->Var("Y");
y->SetType(proto::VarType::LOD_TENSOR);
y->SetLoDLevel(0);
y->SetDataType(proto::FP32);
y->SetDataType(proto::VarType::FP32);
y->SetShape({784, 100});
auto* op = global_block->AppendOp();
......
......@@ -34,7 +34,7 @@ void AddOp(const std::string &type, const f::VariableNameMap &inputs,
for (auto kv : outputs) {
for (auto v : kv.second) {
auto var = block->Var(v);
var->SetDataType(paddle::framework::proto::DataType::FP32);
var->SetDataType(paddle::framework::proto::VarType::FP32);
}
}
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace framework {
......@@ -52,7 +53,9 @@ struct SizeOfTypeFunctor<HEAD, TAIL...> {
};
static inline size_t SizeOfType(std::type_index type) {
SizeOfTypeFunctor<int, float, double, int16_t, int64_t, bool, size_t> functor;
SizeOfTypeFunctor<int, float, double, int16_t, int64_t, bool, size_t,
platform::float16>
functor;
size_t size = functor(type);
PADDLE_ENFORCE(size != 0UL, "Cannot get size of type %s", type.name());
return size;
......
......@@ -87,12 +87,12 @@ std::vector<std::vector<int64_t>> VarDesc::GetShapes() const {
return res;
}
void VarDesc::SetDataType(proto::DataType data_type) {
void VarDesc::SetDataType(proto::VarType::Type data_type) {
mutable_tensor_desc()->set_data_type(data_type);
}
void VarDesc::SetDataTypes(
const std::vector<proto::DataType> &multiple_data_type) {
const std::vector<proto::VarType::Type> &multiple_data_type) {
if (multiple_data_type.size() != GetTensorDescNum()) {
VLOG(3) << "WARNING: The number of given data types("
<< multiple_data_type.size()
......@@ -108,13 +108,13 @@ void VarDesc::SetDataTypes(
}
}
proto::DataType VarDesc::GetDataType() const {
proto::VarType::Type VarDesc::GetDataType() const {
return tensor_desc().data_type();
}
std::vector<proto::DataType> VarDesc::GetDataTypes() const {
std::vector<proto::VarType::Type> VarDesc::GetDataTypes() const {
std::vector<proto::VarType::TensorDesc> descs = tensor_descs();
std::vector<proto::DataType> res;
std::vector<proto::VarType::Type> res;
res.reserve(descs.size());
for (const auto &tensor_desc : descs) {
res.push_back(tensor_desc.data_type());
......
......@@ -80,13 +80,14 @@ class VarDesc {
std::vector<std::vector<int64_t>> GetShapes() const;
void SetDataType(proto::DataType data_type);
void SetDataType(proto::VarType::Type data_type);
void SetDataTypes(const std::vector<proto::DataType> &multiple_data_type);
void SetDataTypes(
const std::vector<proto::VarType::Type> &multiple_data_type);
proto::DataType GetDataType() const;
proto::VarType::Type GetDataType() const;
std::vector<proto::DataType> GetDataTypes() const;
std::vector<proto::VarType::Type> GetDataTypes() const;
void SetLoDLevel(int32_t lod_level);
......
......@@ -36,7 +36,8 @@ class AssignValueOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::proto::DataType(ctx.Attr<int>("dtype")), ctx.GetPlace());
framework::proto::VarType::Type(ctx.Attr<int>("dtype")),
ctx.GetPlace());
}
};
......@@ -49,8 +50,8 @@ class AssignValueOpMaker : public framework::OpProtoAndCheckerMaker {
"(vector<int>) "
"Shape of values.");
AddAttr<int>("dtype", "data type of values")
.InEnum({framework::proto::DataType::INT32,
framework::proto::DataType::FP32});
.InEnum({framework::proto::VarType::INT32,
framework::proto::VarType::FP32});
AddAttr<std::vector<float>>("fp32_values", "store the float values")
.SetDefault({});
AddAttr<std::vector<int>>("int32_values", "store the int values")
......
......@@ -30,10 +30,10 @@ class AssignValueKernel : public framework::OpKernel<T> {
int dtype = ctx.Attr<int>("dtype");
const char* value_name = nullptr;
switch (dtype) {
case framework::proto::DataType::INT32:
case framework::proto::VarType::INT32:
value_name = "int32_values";
break;
case framework::proto::DataType::FP32:
case framework::proto::VarType::FP32:
value_name = "fp32_values";
break;
default:
......
......@@ -55,7 +55,8 @@ class CastOpKernel : public framework::OpKernel<InT> {
auto* in = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out");
framework::VisitDataType(
static_cast<framework::proto::DataType>(context.Attr<int>("out_dtype")),
static_cast<framework::proto::VarType::Type>(
context.Attr<int>("out_dtype")),
CastOpFunctor<DeviceContext, InT>(
in, out, context.template device_context<DeviceContext>()));
}
......
......@@ -57,7 +57,7 @@ class ChunkEvalOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(framework::proto::DataType::FP32,
return framework::OpKernelType(framework::proto::VarType::FP32,
platform::CPUPlace());
}
};
......
......@@ -42,7 +42,7 @@ class EditDistanceOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(framework::proto::DataType::FP32,
return framework::OpKernelType(framework::proto::VarType::FP32,
ctx.device_context());
}
};
......
......@@ -24,7 +24,7 @@ class FillConstantBatchSizeLikeOp : public BatchSizeLikeOp {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
static_cast<framework::proto::DataType>(ctx.Attr<int>("dtype")),
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
ctx.device_context());
}
};
......@@ -36,7 +36,7 @@ class FillConstantBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker {
AddAttr<int>("dtype",
"(int, default 5 (FP32)) "
"Output data type")
.SetDefault(framework::proto::DataType::FP32);
.SetDefault(framework::proto::VarType::FP32);
AddAttr<float>("value", "(float, default 0) The value to be filled")
.SetDefault(0.0f);
AddComment(R"DOC(
......
......@@ -38,7 +38,7 @@ class FillConstantOp : public framework::OperatorBase {
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
auto data_type =
static_cast<framework::proto::DataType>(Attr<int>("dtype"));
static_cast<framework::proto::VarType::Type>(Attr<int>("dtype"));
auto value = Attr<float>("value");
auto force_cpu = Attr<bool>("force_cpu");
auto &out =
......@@ -64,7 +64,7 @@ class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("dtype",
"(int, default 5 (FP32)) "
"Output data type")
.SetDefault(framework::proto::DataType::FP32);
.SetDefault(framework::proto::VarType::FP32);
AddAttr<std::vector<int>>("shape", "(vector<int>) The shape of the output");
AddAttr<float>("value", "(float, default 0) The value to be filled")
.SetDefault(0.0f);
......
......@@ -51,7 +51,8 @@ class FillOp : public framework::OperatorBase {
"Cannot find variable %s", Output("Out"))
.GetMutable<framework::LoDTensor>());
out.Resize(framework::make_ddim(Attr<std::vector<int>>("shape")));
auto dtype = static_cast<framework::proto::DataType>(Attr<int>("dtype"));
auto dtype =
static_cast<framework::proto::VarType::Type>(Attr<int>("dtype"));
platform::CPUPlace cpu;
auto force_cpu = Attr<bool>("force_cpu");
out.mutable_data(force_cpu ? cpu : place, framework::ToTypeIndex(dtype));
......@@ -93,7 +94,7 @@ Fill an tensor with `value` and `shape`. The type of the tensor is specify by
"value", "The float values of tensor, which are flatten in row major");
AddAttr<std::vector<int>>("shape", "The shape of output tensor");
AddAttr<int>("dtype", "The data type of output tensor, Default is float")
.SetDefault(framework::proto::DataType::FP32);
.SetDefault(framework::proto::VarType::FP32);
AddAttr<bool>("force_cpu",
"Whether the output tensor must be at CPU memory or not. "
"Default is false.")
......
......@@ -26,7 +26,7 @@ class GaussianRandomBatchSizeLikeOp : public BatchSizeLikeOp {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
static_cast<framework::proto::DataType>(ctx.Attr<int>("dtype")),
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
ctx.GetPlace());
}
};
......@@ -53,7 +53,7 @@ class GaussianRandomBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker {
AddAttr<int>("dtype",
"(int, default 5(FP32)) "
"Output data type.")
.SetDefault(framework::proto::DataType::FP32);
.SetDefault(framework::proto::VarType::FP32);
AddComment(R"DOC(
GaussianRandom Operator.
......
......@@ -63,7 +63,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
static_cast<framework::proto::DataType>(ctx.Attr<int>("dtype")),
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
ctx.device_context());
}
};
......@@ -95,7 +95,7 @@ class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("dtype",
"(int, default 5(FP32)) "
"Output data type.")
.SetDefault(framework::proto::DataType::FP32);
.SetDefault(framework::proto::VarType::FP32);
AddComment(R"DOC(
GaussianRandom Operator.
......
......@@ -41,10 +41,10 @@ class Im2SequenceOp : public framework::OperatorWithKernel {
int img_height = in_dim[2];
int img_width = in_dim[3];
int output_height = OutputSize(img_height, kernels[0], paddings[0],
paddings[2], strides[0]);
int output_width =
OutputSize(img_width, kernels[1], paddings[1], paddings[3], strides[1]);
int output_height = Im2SeqOutputSize(img_height, kernels[0], paddings[0],
paddings[2], strides[0]);
int output_width = Im2SeqOutputSize(img_width, kernels[1], paddings[1],
paddings[3], strides[1]);
ctx->SetOutputDim("Out", {batch_size * output_height * output_width,
img_channels * kernels[0] * kernels[1]});
......
......@@ -26,8 +26,8 @@ namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
inline int OutputSize(int input_size, int filter_size, int padding_0,
int padding_1, int stride) {
inline int Im2SeqOutputSize(int input_size, int filter_size, int padding_0,
int padding_1, int stride) {
const int output_size =
(input_size + padding_0 + padding_1 - filter_size) / stride + 1;
return output_size;
......@@ -53,10 +53,10 @@ class Im2SequenceKernel : public framework::OpKernel<T> {
auto kernels = ctx.Attr<std::vector<int>>("kernels");
auto strides = ctx.Attr<std::vector<int>>("strides");
auto paddings = ctx.Attr<std::vector<int>>("paddings");
int output_height = OutputSize(img_height, kernels[0], paddings[0],
paddings[2], strides[0]);
int output_width =
OutputSize(img_width, kernels[1], paddings[1], paddings[3], strides[1]);
int output_height = Im2SeqOutputSize(img_height, kernels[0], paddings[0],
paddings[2], strides[0]);
int output_width = Im2SeqOutputSize(img_width, kernels[1], paddings[1],
paddings[3], strides[1]);
const std::vector<int> dilations({1, 1});
......@@ -109,10 +109,10 @@ class Im2SequenceGradKernel : public framework::OpKernel<T> {
auto kernels = ctx.Attr<std::vector<int>>("kernels");
auto strides = ctx.Attr<std::vector<int>>("strides");
auto paddings = ctx.Attr<std::vector<int>>("paddings");
int output_height = OutputSize(img_height, kernels[0], paddings[0],
paddings[2], strides[0]);
int output_width =
OutputSize(img_width, kernels[1], paddings[1], paddings[3], strides[1]);
int output_height = Im2SeqOutputSize(img_height, kernels[0], paddings[0],
paddings[2], strides[0]);
int output_width = Im2SeqOutputSize(img_width, kernels[1], paddings[1],
paddings[3], strides[1]);
const std::vector<int> dilations({1, 1});
......
......@@ -60,7 +60,7 @@ class OneHotOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("dtype",
"An integer to specify the data type of one-hot "
"vector. The default value is FP32.")
.SetDefault(paddle::framework::proto::DataType::FP32);
.SetDefault(paddle::framework::proto::VarType::FP32);
AddComment(R"DOC(
One Hot Operator. This operator creates the one-hot representations for input
index values. The following example will help to explain the function of this
......
......@@ -65,7 +65,8 @@ class OneHotCUDAKernel : public framework::OpKernel<T> {
int depth = context.Attr<int>("depth");
framework::VisitDataType(
static_cast<framework::proto::DataType>(context.Attr<int>("dtype")),
static_cast<framework::proto::VarType::Type>(
context.Attr<int>("dtype")),
OneHotOpCUDAFunctor<DeviceContext, T>(
in, out, depth, context.template device_context<DeviceContext>()));
}
......
......@@ -58,7 +58,8 @@ class OneHotKernel : public framework::OpKernel<T> {
int depth = context.Attr<int>("depth");
framework::VisitDataType(
static_cast<framework::proto::DataType>(context.Attr<int>("dtype")),
static_cast<framework::proto::VarType::Type>(
context.Attr<int>("dtype")),
OneHotOpFunctor<DeviceContext, T>(
in, out, depth, context.template device_context<DeviceContext>()));
}
......
......@@ -17,7 +17,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
int OutputSizePool(int input_size, int filter_size, int padding, int stride) {
int PoolOutputSize(int input_size, int filter_size, int padding, int stride) {
int output_size = (input_size - filter_size + 2 * padding) / stride + 1;
return output_size;
}
......@@ -55,7 +55,7 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const {
std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1]});
for (size_t i = 0; i < ksize.size(); ++i) {
output_shape.push_back(
OutputSizePool(in_x_dims[i + 2], ksize[i], paddings[i], strides[i]));
PoolOutputSize(in_x_dims[i + 2], ksize[i], paddings[i], strides[i]));
}
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
ctx->ShareLoD("X", "Out");
......
......@@ -17,7 +17,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
inline int OutputSizeMaxPool(int input_size, int filter_size, int padding,
inline int MaxPoolOutputSize(int input_size, int filter_size, int padding,
int stride) {
int output_size = (input_size - filter_size + 2 * padding) / stride + 1;
return output_size;
......@@ -61,7 +61,7 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel {
std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1]});
for (size_t i = 0; i < ksize.size(); ++i) {
output_shape.push_back(OutputSizeMaxPool(in_x_dims[i + 2], ksize[i],
output_shape.push_back(MaxPoolOutputSize(in_x_dims[i + 2], ksize[i],
paddings[i], strides[i]));
}
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
......
......@@ -66,7 +66,7 @@ class RNNMemoryHelperOpInfoMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("dtype",
"(int, default 5 (FP32)) "
"Output data type")
.SetDefault(framework::proto::DataType::FP32);
.SetDefault(framework::proto::VarType::FP32);
AddComment("");
}
};
......@@ -126,7 +126,7 @@ class RNNMemoryHelperGradOpInfoMaker
AddAttr<int>("dtype",
"(int, default 5 (FP32)) "
"Output data type")
.SetDefault(framework::proto::DataType::FP32);
.SetDefault(framework::proto::VarType::FP32);
AddComment("");
}
};
......
......@@ -73,7 +73,8 @@ class SumOp : public framework::OperatorWithKernel {
"Sum operator should have at least one tensor");
return framework::OpKernelType(
static_cast<framework::proto::DataType>(dtype), ctx.device_context());
static_cast<framework::proto::VarType::Type>(dtype),
ctx.device_context());
} else if (x_vars[0]->IsType<framework::SelectedRows>()) {
return framework::OpKernelType(
framework::ToDataType(
......
......@@ -26,7 +26,7 @@ class UniformRandomBatchSizeLikeOp : public BatchSizeLikeOp {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
static_cast<framework::proto::DataType>(ctx.Attr<int>("dtype")),
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
ctx.GetPlace());
}
};
......@@ -58,7 +58,7 @@ This operator initializes a tensor with the same batch_size as the Input tensor
"generate the same random numbers every time.")
.SetDefault(0);
AddAttr<int>("dtype", "(int, default 5(FP32)) Output tensor data type")
.SetDefault(framework::proto::DataType::FP32);
.SetDefault(framework::proto::VarType::FP32);
}
};
......
......@@ -66,7 +66,7 @@ class UniformRandomOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
static_cast<framework::proto::DataType>(ctx.Attr<int>("dtype")),
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
ctx.GetPlace());
}
};
......@@ -101,7 +101,7 @@ uniform distribution.
"generate the same random numbers every time.")
.SetDefault(0);
AddAttr<int>("dtype", "(int, default 5(FP32)) Output tensor data type")
.SetDefault(framework::proto::DataType::FP32);
.SetDefault(framework::proto::VarType::FP32);
}
};
} // namespace operators
......
......@@ -64,7 +64,7 @@ Paper: http://www.matthewzeiler.com/wp-content/uploads/2017/07/iccv2011.pdf
}
};
int OutputSize(int input_size, int ksize, int padding, int stride) {
int UnpoolOutputSize(int input_size, int ksize, int padding, int stride) {
int output_size = (input_size - 1) * stride - 2 * padding + ksize;
return output_size;
}
......@@ -101,8 +101,8 @@ class UnpoolOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(in_x_dims, in_y_dims);
std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1]});
for (size_t i = 0; i < ksize.size(); ++i) {
output_shape.push_back(
OutputSize(in_x_dims[i + 2], ksize[i], paddings[i], strides[i]));
output_shape.push_back(UnpoolOutputSize(in_x_dims[i + 2], ksize[i],
paddings[i], strides[i]));
}
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
}
......
......@@ -62,6 +62,7 @@ limitations under the License. */
#define PADDLE_ALIGN(x) __attribute__((aligned(x)))
namespace paddle {
namespace platform {
// Use PADDLE_ALIGNED(2) to ensure that each float16 will be allocated
// and aligned at least on a 2-byte boundary, which leads to efficient
......@@ -71,11 +72,21 @@ struct PADDLE_ALIGN(2) float16 {
public:
uint16_t x;
// Constructors
HOSTDEVICE inline float16() : x(0) {}
// The following defaulted special class member functions
// are added to make float16 pass the std::is_trivial test
HOSTDEVICE inline float16() = default;
HOSTDEVICE inline float16(const float16& h) : x(h.x) {}
HOSTDEVICE inline float16(const float16&) = default;
HOSTDEVICE inline float16& operator=(const float16&) = default;
HOSTDEVICE inline float16(float16&&) = default;
HOSTDEVICE inline float16& operator=(float16&&) = default;
HOSTDEVICE inline ~float16() = default;
// Constructors
#ifdef PADDLE_CUDA_FP16
HOSTDEVICE inline explicit float16(const half& h) {
#if CUDA_VERSION >= 9000
......@@ -136,11 +147,6 @@ struct PADDLE_ALIGN(2) float16 {
HOSTDEVICE inline explicit float16(const T& val)
: x(float16(static_cast<float>(val)).x) {}
HOSTDEVICE inline float16& operator=(const float16& rhs) {
x = rhs.x;
return *this;
}
// Assignment operators
#ifdef PADDLE_CUDA_FP16
HOSTDEVICE inline float16& operator=(const half& rhs) {
......@@ -727,4 +733,25 @@ HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) {
return float(a) >= float(b);
}
#endif
} // namespace platform
} // namespace paddle
namespace std {
// Override the std::is_pod::value for float16
// The reason is that different compilers implemented std::is_pod based on
// different C++ standards. float16 class is a plain old data in C++11 given
// that it is both trivial and standard_layout.
// However, std::is_pod in nvcc 8.0 host c++ compiler follows C++0x and is
// more restricted in that you cannot provide any customized
// constructor in float16. Hence, we override is_pod here following C++11
// so that .cu files can be successfully compiled by nvcc.
template <>
struct is_pod<paddle::platform::float16> {
static const bool value =
is_trivial<paddle::platform::float16>::value &&
is_standard_layout<paddle::platform::float16>::value;
};
} // namespace std
......@@ -10,10 +10,13 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/framework/init.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include <gtest/gtest.h>
namespace paddle {
namespace platform {
TEST(float16, conversion_cpu) {
// Explicit conversion from Eigen::half
......@@ -54,13 +57,9 @@ TEST(float16, conversion_cpu) {
EXPECT_EQ(float16(true).x, 0x3c00);
EXPECT_EQ(float16(false).x, 0x0000);
// Default constructor
float16 v_def;
EXPECT_EQ(v_def.x, 0x0000);
// Assignment operator
float16 v_assign;
v_assign = v_def;
v_assign = float16(0);
EXPECT_EQ(v_assign.x, 0x0000);
v_assign = Eigen::half(1.0f);
EXPECT_EQ(v_assign.x, 0x3c00);
......@@ -116,4 +115,27 @@ TEST(float16, comparison_cpu) {
EXPECT_FALSE(float16(-0.0f) > float16(0.0f));
}
TEST(float16, lod_tensor_cpu) {
framework::LoDTensor lod_tensor;
std::vector<float16> input_data = {float16(1.0f), float16(0.5f),
float16(0.33333f), float16(0.0f)};
EXPECT_EQ(input_data[0].x, 0x3c00);
EXPECT_EQ(input_data[1].x, 0x3800);
EXPECT_EQ(input_data[2].x, 0x3555);
EXPECT_EQ(input_data[3].x, 0x0000);
lod_tensor.Resize({4, 1});
lod_tensor.set_lod(framework::LoD({{0, 2, 4}}));
float16* data_ptr = lod_tensor.mutable_data<float16>(CPUPlace());
EXPECT_NE(data_ptr, nullptr);
EXPECT_EQ(input_data.size(), static_cast<size_t>(lod_tensor.numel()));
for (size_t i = 0; i < input_data.size(); ++i) {
data_ptr[i] = input_data[i];
EXPECT_EQ(data_ptr[i].x, input_data[i].x);
}
}
} // namespace platform
} // namespace paddle
......@@ -13,6 +13,8 @@ limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/utils/Logging.h"
#define ARITHMETIC_KERNEL(op_type, sign) \
......@@ -108,6 +110,7 @@ limitations under the License. */
#ifdef PADDLE_CUDA_FP16
namespace paddle {
namespace platform {
#if CUDA_VERSION < 9000
ARITHMETIC_KERNEL(Add, +)
......@@ -209,5 +212,35 @@ TEST(float16, conversion_on_gpu) {
EXPECT_EQ(v_assign.x, 0x3c00);
}
TEST(float16, lod_tensor_on_gpu) {
framework::LoDTensor src_tensor;
framework::LoDTensor gpu_tensor;
framework::LoDTensor dst_tensor;
float16* src_ptr = src_tensor.mutable_data<float16>(
framework::make_ddim({2, 2}), CPUPlace());
float16 arr[4] = {float16(1.0f), float16(0.5f), float16(0.33333f),
float16(0.0f)};
memcpy(src_ptr, arr, 4 * sizeof(float16));
// CPU LoDTensor to GPU LoDTensor
CUDAPlace gpu_place(0);
CUDADeviceContext gpu_ctx(gpu_place);
framework::TensorCopy(src_tensor, gpu_place, gpu_ctx, &gpu_tensor);
// GPU LoDTensor to CPU LoDTensor
framework::TensorCopy(gpu_tensor, CPUPlace(), gpu_ctx, &dst_tensor);
// Sync before comparing LoDTensors
gpu_ctx.Wait();
const float16* dst_ptr = dst_tensor.data<float16>();
ASSERT_NE(src_ptr, dst_ptr);
for (size_t i = 0; i < 4; ++i) {
EXPECT_EQ(src_ptr[i].x, dst_ptr[i].x);
}
}
} // namespace platform
} // namespace paddle
#endif // PADDLE_CUDA_FP16
......@@ -195,15 +195,6 @@ void BindBlockDesc(py::module &m) {
}
void BindVarDsec(py::module &m) {
py::enum_<proto::DataType>(m, "DataType", "")
.value("BOOL", proto::DataType::BOOL)
.value("INT16", proto::DataType::INT16)
.value("INT32", proto::DataType::INT32)
.value("INT64", proto::DataType::INT64)
.value("FP16", proto::DataType::FP16)
.value("FP32", proto::DataType::FP32)
.value("FP64", proto::DataType::FP64);
py::class_<VarDesc> var_desc(m, "VarDesc", "");
var_desc
.def("name",
......@@ -233,6 +224,13 @@ void BindVarDsec(py::module &m) {
.def("set_persistable", &VarDesc::SetPersistable);
py::enum_<proto::VarType::Type>(var_desc, "VarType", "")
.value("BOOL", proto::VarType::BOOL)
.value("INT16", proto::VarType::INT16)
.value("INT32", proto::VarType::INT32)
.value("INT64", proto::VarType::INT64)
.value("FP16", proto::VarType::FP16)
.value("FP32", proto::VarType::FP32)
.value("FP64", proto::VarType::FP64)
.value("LOD_TENSOR", proto::VarType::LOD_TENSOR)
.value("SELECTED_ROWS", proto::VarType::SELECTED_ROWS)
.value("FEED_MINIBATCH", proto::VarType::FEED_MINIBATCH)
......
......@@ -34,13 +34,15 @@ from data_feeder import DataFeeder
from core import LoDTensor, CPUPlace, CUDAPlace
from distribute_transpiler import DistributeTranspiler
from distribute_transpiler_simple import SimpleDistributeTranspiler
from concurrency import (Go, make_channel, channel_send, channel_recv,
channel_close)
import clip
from memory_optimization_transpiler import memory_optimize
import profiler
Tensor = LoDTensor
__all__ = framework.__all__ + executor.__all__ + [
__all__ = framework.__all__ + executor.__all__ + concurrency.__all__ + [
'io',
'initializer',
'layers',
......
......@@ -68,7 +68,7 @@ def _infer_var_data_type_(grad_var_name, block):
fwd_var = block.desc.find_var_recursive(fwd_name.encode("ascii"))
grad_var.set_dtype(fwd_var.dtype())
else:
grad_var.set_dtype(core.DataType.FP32)
grad_var.set_dtype(core.VarDesc.VarType.FP32)
def _all_in_set_(cands, s):
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO: Variables: make_channel
# TODO: Operators: send, close_channel, recv, go, select
from layers.control_flow import BlockGuard
from layer_helper import LayerHelper
__all__ = [
'Go',
'make_channel',
'channel_send',
'channel_recv',
'channel_close',
]
class Go(BlockGuard):
def __init__(self, name=None):
self.helper = LayerHelper("go", name=name)
super(Go, self).__init__(self.helper.main_program)
def __enter__(self):
super(Go, self).__enter__()
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type is not None:
return False
self.construct_go_op()
return super(Go, self).__exit__(exc_type, exc_val, exc_tb)
def construct_go_op(self):
main_program = self.helper.main_program
go_block = main_program.current_block()
parent_block = main_program.block(main_program.current_block()
.parent_idx)
x_name_list = set()
out_vars = set()
for op in go_block.ops:
# Iterate over all operators, get all the inputs
# and add as input to the Go operator.
for iname in op.input_names:
for in_var_name in op.input(iname):
x_name_list.add(in_var_name)
# Iterate over all operators , get all the outputs
# add to the output list of Go operator only if
# they exist in the parent block.
for oname in op.output_names:
for out_var_name in op.output(oname):
if out_var_name in parent_block.vars:
out_vars.add(parent_block.var(out_var_name))
parent_block.append_op(
type='go',
inputs={'X': [parent_block.var(x_name) for x_name in x_name_list]},
outputs={'Out': out_vars},
attrs={'sub_block': go_block})
def make_channel(dtype, size=0):
return True
def channel_send(channel, value):
return True
def channel_recv(channel):
return True
def channel_close(channel):
return True
......@@ -27,13 +27,13 @@ class DataToLoDTensorConverter(object):
self.place = place
self.lod_level = lod_level
self.shape = shape
if dtype == core.DataType.FP32:
if dtype == core.VarDesc.VarType.FP32:
self.dtype = 'float32'
elif dtype == core.DataType.INT64:
elif dtype == core.VarDesc.VarType.INT64:
self.dtype = 'int64'
elif dtype == core.DataType.FP64:
elif dtype == core.VarDesc.VarType.FP64:
self.dtype = 'float64'
elif dtype == core.DataType.INT32:
elif dtype == core.VarDesc.VarType.INT32:
self.dtype = 'int32'
else:
raise ValueError("dtype must be any of [int32, float32, int64, "
......
......@@ -89,7 +89,7 @@ class Evaluator(object):
Args:
suffix(str): the state suffix.
dtype(str|core.DataType): the state data type
dtype(str|core.VarDesc.VarType): the state data type
shape(tuple|list): the shape of state
Returns: State variable
......
......@@ -67,24 +67,24 @@ def convert_np_dtype_to_dtype_(np_dtype):
Args:
np_dtype(np.dtype): the data type in numpy
Returns(core.DataType): the data type in Paddle
Returns(core.VarDesc.VarType): the data type in Paddle
"""
dtype = np.dtype(np_dtype)
if dtype == np.float32:
return core.DataType.FP32
return core.VarDesc.VarType.FP32
elif dtype == np.float64:
return core.DataType.FP64
return core.VarDesc.VarType.FP64
elif dtype == np.float16:
return core.DataType.FP16
return core.VarDesc.VarType.FP16
elif dtype == np.int32:
return core.DataType.INT32
return core.VarDesc.VarType.INT32
elif dtype == np.int16:
return core.DataType.INT16
return core.VarDesc.VarType.INT16
elif dtype == np.int64:
return core.DataType.INT64
return core.VarDesc.VarType.INT64
elif dtype == np.bool:
return core.DataType.BOOL
return core.VarDesc.VarType.BOOL
else:
raise ValueError("Not supported numpy dtype " + str(dtype))
......@@ -93,16 +93,19 @@ def dtype_is_floating(dtype):
"""
Check the data type is floating or not.
Args:
dtype(np.dtype|core.DataType): data type.
dtype(np.dtype|core.VarDesc.VarType): data type.
Could be numpy format or Paddle format
Returns(bool): True if data type is a float value
"""
if not isinstance(dtype, core.DataType):
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
return dtype in [core.DataType.FP16, core.DataType.FP32, core.DataType.FP64]
return dtype in [
core.VarDesc.VarType.FP16, core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP64
]
def _debug_string_(proto, throw_on_error=True):
......@@ -148,7 +151,7 @@ class Variable(object):
framework.proto for details.
shape(tuple|list|None): The shape of variable. -1 means the batch size.
Some kinds of variable do not contain shape, just set it to None.
dtype(np.dtype|core.DataType|str): The data type of variable.
dtype(np.dtype|core.VarDesc.VarType|str): The data type of variable.
lod_level(int): The level of lod tensor. 0 means there is not a time
series data.
persistable(bool): True if the variable should be saved as check point.
......@@ -200,7 +203,7 @@ class Variable(object):
"shape is {1}; the new shape is {2}. They are not "
"matched.".format(self.name, old_shape, shape))
if dtype is not None:
if not isinstance(dtype, core.DataType):
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
if is_new_var:
self.desc.set_dtype(dtype)
......
......@@ -614,7 +614,7 @@ class While(object):
if not isinstance(cond, Variable):
raise TypeError("condition should be a variable")
assert isinstance(cond, Variable)
if cond.dtype != core.DataType.BOOL:
if cond.dtype != core.VarDesc.VarType.BOOL:
raise TypeError("condition should be a bool variable")
if reduce(lambda a, b: a * b, cond.shape, 1) != 1:
raise TypeError("condition should be a bool scalar")
......
......@@ -221,7 +221,7 @@ def embedding(input,
:math:`padding_idx < 0`, the padding_idx to use in lookup is
:math:`size[0] + dim`.
param_attr(ParamAttr): Parameters for this layer
dtype(np.dtype|core.DataType|str): The type of data : float32, float_16, int etc
dtype(np.dtype|core.VarDesc.VarType|str): The type of data : float32, float_16, int etc
Returns:
Variable: The tensor variable storing the embeddings of the \
......
......@@ -17,7 +17,7 @@ from ..param_attr import ParamAttr
from ..framework import convert_np_dtype_to_dtype_
from ..framework import Variable
from ..initializer import Constant, force_init_on_cpu
from ..core import DataType
from ..core import VarDesc
import numpy
__all__ = [
......@@ -199,10 +199,10 @@ def assign(input, output):
attrs={'scale': 1.0})
elif isinstance(input, numpy.ndarray):
dtype = convert_np_dtype_to_dtype_(input.dtype)
if dtype == DataType.FP32:
if dtype == VarDesc.VarType.FP32:
value_name = "fp32_values"
values = [float(v) for v in input.flat]
elif dtype == DataType.INT32:
elif dtype == VarDesc.VarType.INT32:
value_name = "int32_values"
values = [int(v) for v in input.flat]
else:
......@@ -236,7 +236,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None):
Args:
shape(tuple|list|None): Shape of the output tensor.
dtype(np.dtype|core.DataType|str): Data type of the output tensor.
dtype(np.dtype|core.VarDesc.VarType|str): Data type of the output tensor.
value(float): The constant value used to initialize the output tensor.
out(Variable): The output tensor.
force_cpu(True|False): data should be on CPU if set true.
......@@ -285,7 +285,7 @@ def fill_constant_batch_size_like(input,
Args:
input(Variable): Tensor whose dimensions will be used to get batch size
shape(tuple|list|None): Shape of output tensor
dtype(np.dtype|core.DataType|str): Data type of output tensor
dtype(np.dtype|core.VarDesc.VarType|str): Data type of output tensor
value(float): Constant value to initialize the output tensor
input_dim_idx(int): Index of input's batch size dimension
output_dim_idx(int): Index of output's batch size dimension
......@@ -327,7 +327,7 @@ def ones(shape, dtype, force_cpu=False):
Args:
shape(tuple|list|None): Shape of output tensor
dtype(np.dtype|core.DataType|str): Data type of output tensor
dtype(np.dtype|core.VarDesc.VarType|str): Data type of output tensor
Returns:
Variable: The tensor variable storing the output
......@@ -351,7 +351,7 @@ def zeros(shape, dtype, force_cpu=False):
Args:
shape(tuple|list|None): Shape of output tensor
dtype(np.dtype|core.DataType|str): Data type of output tensor
dtype(np.dtype|core.VarDesc.VarType|str): Data type of output tensor
Returns:
Variable: The tensor variable storing the output
......
......@@ -20,13 +20,13 @@ from backward import _rename_arg_
from . import core
dtype_to_size = {
core.DataType.FP16: 2,
core.DataType.FP32: 4,
core.DataType.FP64: 8,
core.DataType.INT16: 2,
core.DataType.INT32: 4,
core.DataType.INT64: 8,
core.DataType.BOOL: 1
core.VarDesc.VarType.FP16: 2,
core.VarDesc.VarType.FP32: 4,
core.VarDesc.VarType.FP64: 8,
core.VarDesc.VarType.INT16: 2,
core.VarDesc.VarType.INT32: 4,
core.VarDesc.VarType.INT64: 8,
core.VarDesc.VarType.BOOL: 1
}
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle.v2.fluid as fluid
import paddle.v2.fluid.core as core
from paddle.v2.fluid.executor import Executor
class TestRoutineOp(unittest.TestCase):
def test_simple_routine(self):
ch = fluid.make_channel(dtype=bool)
with fluid.Go():
fluid.channel_send(ch, True)
result = fluid.channel_recv(ch)
fluid.channel_close(ch)
cpu = core.CPUPlace()
exe = Executor(cpu)
outs = exe.run(fetch_list=[result])
self.assertEqual(outs[0], True)
if __name__ == '__main__':
unittest.main()
......@@ -22,7 +22,7 @@ block = prog.current_block()
random_reader = block.create_var(
type=fluid.core.VarDesc.VarType.READER, name="RandomDataGenerator")
random_reader.desc.set_dtypes(
[fluid.core.DataType.FP32, fluid.core.DataType.FP32])
[fluid.core.VarDesc.VarType.FP32, fluid.core.VarDesc.VarType.FP32])
create_random_data_generator_op = block.append_op(
type="create_random_data_generator",
......
......@@ -119,9 +119,9 @@ def get_numeric_gradient(place,
tensor_to_check = scope.find_var(input_to_check).get_tensor()
tensor_size = product(tensor_to_check.get_dims())
tensor_to_check_dtype = tensor_to_check.dtype()
if tensor_to_check_dtype == core.DataType.FP32:
if tensor_to_check_dtype == core.VarDesc.VarType.FP32:
tensor_to_check_dtype = np.float32
elif tensor_to_check_dtype == core.DataType.FP64:
elif tensor_to_check_dtype == core.VarDesc.VarType.FP64:
tensor_to_check_dtype = np.float64
else:
raise ValueError("Not supported data type " + str(
......
......@@ -140,9 +140,9 @@ def set_output_grad(scope, outputs, place, feed_dict=None):
grad_tensor = scope.var(grad_var_name(name)).get_tensor()
out_dtype = out_tensor.dtype()
if data is None:
if out_dtype == core.DataType.FP64:
if out_dtype == core.VarDesc.VarType.FP64:
data = np.ones(out_tensor.shape(), dtype=np.float64)
elif out_dtype == core.DataType.FP32:
elif out_dtype == core.VarDesc.VarType.FP32:
data = np.ones(out_tensor.shape(), dtype=np.float32)
else:
raise ValueError("Not supported data type " + str(out_dtype))
......
......@@ -24,8 +24,8 @@ class TestCastOp(op_test.OpTest):
self.inputs = {'X': ipt.astype('float32')}
self.outputs = {'Out': ipt.astype('float64')}
self.attrs = {
'in_dtype': int(core.DataType.FP32),
'out_dtype': int(core.DataType.FP64)
'in_dtype': int(core.VarDesc.VarType.FP32),
'out_dtype': int(core.VarDesc.VarType.FP64)
}
self.op_type = 'cast'
......
......@@ -26,7 +26,7 @@ class TestFillOp(OpTest):
self.attrs = {
'value': val.flatten().tolist(),
'shape': [100, 200],
'dtype': int(core.DataType.FP64)
'dtype': int(core.VarDesc.VarType.FP64)
}
self.outputs = {'Out': val.astype('float64')}
......
......@@ -97,9 +97,9 @@ def set_output_grad(scope, outputs, place, feed_dict=None):
grad_tensor = scope.var(grad_var_name(name)).get_tensor()
out_dtype = out_tensor.dtype()
if data is None:
if out_dtype == core.DataType.FP64:
if out_dtype == core.VarDesc.VarType.FP64:
data = np.ones(out_tensor.shape(), dtype=np.float64)
elif out_dtype == core.DataType.FP32:
elif out_dtype == core.VarDesc.VarType.FP32:
data = np.ones(out_tensor.shape(), dtype=np.float32)
else:
raise ValueError("Not supported data type " + str(out_dtype))
......
......@@ -38,7 +38,7 @@ class TestOneHotOp(OpTest):
out[i, x[i]] = 1.0
self.inputs = {'X': (x, x_lod)}
self.attrs = {'depth': depth, 'dtype': int(core.DataType.FP32)}
self.attrs = {'depth': depth, 'dtype': int(core.VarDesc.VarType.FP32)}
self.outputs = {'Out': (out, x_lod)}
def test_check_output(self):
......
......@@ -36,7 +36,7 @@ class TestParameter(unittest.TestCase):
self.assertIsNotNone(param)
self.assertEqual('fc.w', param.name)
self.assertEqual((784, 100), param.shape)
self.assertEqual(core.DataType.FP32, param.dtype)
self.assertEqual(core.VarDesc.VarType.FP32, param.dtype)
self.assertEqual(0, param.block.idx)
exe = Executor(core.CPUPlace())
p = exe.run(main_program, fetch_list=[param])[0]
......
......@@ -131,8 +131,8 @@ class TestVarDesc(unittest.TestCase):
block = program_desc.block(0)
var = block.var('my_var')
var.set_type(core.VarDesc.VarType.LOD_TENSOR)
var.set_dtype(core.DataType.INT32)
self.assertEqual(core.DataType.INT32, var.dtype())
var.set_dtype(core.VarDesc.VarType.INT32)
self.assertEqual(core.VarDesc.VarType.INT32, var.dtype())
self.assertEqual(core.VarDesc.VarType.LOD_TENSOR, var.type())
def test_multiple_dtype(self):
......@@ -141,7 +141,8 @@ class TestVarDesc(unittest.TestCase):
var = block.var('my_reader')
var.set_type(core.VarDesc.VarType.READER)
src_types = [
core.DataType.INT32, core.DataType.FP64, core.DataType.FP32
core.VarDesc.VarType.INT32, core.VarDesc.VarType.FP64,
core.VarDesc.VarType.FP32
]
var.set_dtypes(src_types)
self.assertEqual(src_types, var.dtypes())
......
......@@ -20,7 +20,7 @@ import numpy as np
class TestVariable(unittest.TestCase):
def test_np_dtype_convert(self):
DT = core.DataType
DT = core.VarDesc.VarType
convert = convert_np_dtype_to_dtype_
self.assertEqual(DT.FP32, convert(np.float32))
self.assertEqual(DT.FP16, convert("float16"))
......@@ -36,13 +36,13 @@ class TestVariable(unittest.TestCase):
w = b.create_var(
dtype="float64", shape=[784, 100], lod_level=0, name="fc.w")
self.assertNotEqual(str(w), "")
self.assertEqual(core.DataType.FP64, w.dtype)
self.assertEqual(core.VarDesc.VarType.FP64, w.dtype)
self.assertEqual((784, 100), w.shape)
self.assertEqual("fc.w", w.name)
self.assertEqual(0, w.lod_level)
w = b.create_var(name='fc.w')
self.assertEqual(core.DataType.FP64, w.dtype)
self.assertEqual(core.VarDesc.VarType.FP64, w.dtype)
self.assertEqual((784, 100), w.shape)
self.assertEqual("fc.w", w.name)
self.assertEqual(0, w.lod_level)
......
......@@ -12,7 +12,7 @@ with newer version compilers cannot work with those with older
versions. The suggested building environment is as old as CentOS 5.
However, PaddlePaddle relies on CUDA, and the earlies version of
[CentOS works with CUDA is 6](https://hub.docker.com/r/nvidia/cuda/).
So, here we provide a Docker image basing on CentOS 6 and CUDA for
So, here we provide a Docker image based on CentOS 6 and CUDA for
building PaddlePaddle and making the release supports "as-manylinux as
possible." or "sufficiently many Linux" according to [this
discussion](https://mail.python.org/pipermail/wheel-builders/2016-July/000175.html).
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册