提交 9ee8a0d0 编写于 作者: C chengduoZH

remove conflict

...@@ -28,3 +28,4 @@ cmake_install.cmake ...@@ -28,3 +28,4 @@ cmake_install.cmake
paddle/.timestamp paddle/.timestamp
python/paddlepaddle.egg-info/ python/paddlepaddle.egg-info/
paddle/pybind/pybind.h paddle/pybind/pybind.h
python/paddle/v2/framework/tests/tmp/*
...@@ -15,7 +15,7 @@ nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor) ...@@ -15,7 +15,7 @@ nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor)
cc_test(variable_test SRCS variable_test.cc) cc_test(variable_test SRCS variable_test.cc)
cc_library(scope SRCS scope.cc) cc_library(scope SRCS scope.cc DEPS glog)
cc_test(scope_test SRCS scope_test.cc DEPS scope) cc_test(scope_test SRCS scope_test.cc DEPS scope)
...@@ -24,9 +24,10 @@ cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc) ...@@ -24,9 +24,10 @@ cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc)
cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute) cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute)
cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker) cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto) cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog) cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog shape_inference)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute ddim op_info operator glog) cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog)
cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc) cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry) cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
......
...@@ -34,5 +34,25 @@ inline DataType ToDataType(std::type_index type) { ...@@ -34,5 +34,25 @@ inline DataType ToDataType(std::type_index type) {
} }
} }
template <typename Visitor>
inline void VisitDataType(DataType type, Visitor visitor) {
switch (type) {
case DataType::FP32:
visitor.template operator()<float>();
break;
case DataType::FP64:
visitor.template operator()<double>();
break;
case DataType::INT32:
visitor.template operator()<int>();
break;
case DataType::INT64:
visitor.template operator()<int64_t>();
break;
default:
PADDLE_THROW("Not supported");
}
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -16,15 +16,51 @@ limitations under the License. */ ...@@ -16,15 +16,51 @@ limitations under the License. */
#include <functional> #include <functional>
#include <mutex> #include <mutex>
#include <unordered_map> #include <unordered_map>
#include "glog/logging.h"
#include "paddle/framework/block_desc.h" #include "paddle/framework/block_desc.h"
#include "paddle/framework/operator.h" #include "paddle/framework/operator.h"
#include "paddle/framework/program_desc.h" #include "paddle/framework/program_desc.h"
#include "paddle/framework/shape_inference.h"
#include "glog/logging.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class OpDescBind;
class BlockDescBind;
class CompileTimeInferShapeContext : public InferShapeContext {
public:
CompileTimeInferShapeContext(const OpDescBind &op,
const BlockDescBind &block);
bool HasInput(const std::string &name) const override;
bool HasOutput(const std::string &name) const override;
bool HasInputs(const std::string &name) const override;
bool HasOutputs(const std::string &name) const override;
DDim GetInputDim(const std::string &name) const override;
void SetOutputDim(const std::string &name, const DDim &dim) override;
AttrReader Attrs() const override;
const std::vector<std::string> &Inputs(
const std::string &name) const override;
const std::vector<std::string> &Outputs(
const std::string &name) const override;
private:
DDim GetDim(const std::string &name) const override;
void SetDim(const std::string &name, const DDim &dim) override;
const OpDescBind &op_;
const BlockDescBind &block_;
};
OpDescBind::OpDescBind(const std::string &type, const VariableNameMap &inputs, OpDescBind::OpDescBind(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const VariableNameMap &outputs,
const AttributeMap &attrs) { const AttributeMap &attrs) {
...@@ -288,5 +324,97 @@ void OpDescBind::InferVarType(BlockDescBind *block) const { ...@@ -288,5 +324,97 @@ void OpDescBind::InferVarType(BlockDescBind *block) const {
} }
} }
CompileTimeInferShapeContext::CompileTimeInferShapeContext(
const OpDescBind &op, const BlockDescBind &block)
: op_(op), block_(block) {}
bool CompileTimeInferShapeContext::HasInput(const std::string &name) const {
const std::vector<std::string> &input_names = op_.Input(name);
auto length = input_names.size();
if (length == 0) {
return false;
}
PADDLE_ENFORCE_EQ(length, 1UL,
"Input(%s) should have only one value, "
"but it have %d now",
name, length);
return block_.HasVarRecursive(input_names[0]);
}
bool CompileTimeInferShapeContext::HasOutput(const std::string &name) const {
const std::vector<std::string> &output_names = op_.Output(name);
auto length = output_names.size();
if (length == 0) {
return false;
}
PADDLE_ENFORCE_EQ(length, 1UL,
"Output(%s) should have only one value, "
"but it have %d now",
name, length);
return block_.HasVarRecursive(output_names[0]);
}
bool CompileTimeInferShapeContext::HasInputs(const std::string &name) const {
const std::vector<std::string> &input_names = op_.Input(name);
if (input_names.empty()) {
return false;
}
for (auto &input : input_names) {
if (!block_.HasVarRecursive(input)) return false;
}
return true;
}
bool CompileTimeInferShapeContext::HasOutputs(const std::string &name) const {
const std::vector<std::string> &output_names = op_.Output(name);
if (output_names.empty()) {
return false;
}
for (auto &output : output_names) {
if (!block_.HasVarRecursive(output)) return false;
}
return true;
}
DDim CompileTimeInferShapeContext::GetInputDim(const std::string &name) const {
std::vector<DDim> ddims = GetInputsDim(name);
auto length = ddims.size();
PADDLE_ENFORCE_EQ(length, 1UL,
"Input(%s) should have 1 value, "
"but it has %d now",
name, length);
return ddims[0];
}
void CompileTimeInferShapeContext::SetOutputDim(const std::string &name,
const DDim &dim) {
SetOutputsDim(name, {dim});
}
AttrReader CompileTimeInferShapeContext::Attrs() const {
return AttrReader(op_.GetAttrMap());
}
const std::vector<std::string> &CompileTimeInferShapeContext::Inputs(
const std::string &name) const {
return op_.Input(name);
}
const std::vector<std::string> &CompileTimeInferShapeContext::Outputs(
const std::string &name) const {
return op_.Output(name);
}
DDim CompileTimeInferShapeContext::GetDim(const std::string &name) const {
auto var = block_.FindVarRecursive(name);
PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name);
return framework::make_ddim(var->Shape());
}
void CompileTimeInferShapeContext::SetDim(const std::string &name,
const DDim &dim) {
block_.FindVarRecursive(name)->SetShape(framework::vectorize(dim));
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -107,6 +107,8 @@ class OpDescBind { ...@@ -107,6 +107,8 @@ class OpDescBind {
void InferVarType(BlockDescBind *block) const; void InferVarType(BlockDescBind *block) const;
void MarkAsTarget() { desc_.set_is_target(true); }
void Flush(); void Flush();
private: private:
......
...@@ -29,6 +29,7 @@ limitations under the License. */ ...@@ -29,6 +29,7 @@ limitations under the License. */
#include "paddle/framework/op_desc.h" #include "paddle/framework/op_desc.h"
#include "paddle/framework/operator.h" #include "paddle/framework/operator.h"
#include "paddle/framework/scope.h" #include "paddle/framework/scope.h"
#include "paddle/framework/shape_inference.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -161,6 +162,10 @@ class OpKernelRegistrar : public Registrar { ...@@ -161,6 +162,10 @@ class OpKernelRegistrar : public Registrar {
REGISTER_OPERATOR(op_type, op_class, _GradOpDescMaker_##grad_op_type##_, \ REGISTER_OPERATOR(op_type, op_class, _GradOpDescMaker_##grad_op_type##_, \
op_maker_class); op_maker_class);
#define REGISTER_OP_WITH_KERNEL(op_type, ...) \
REGISTER_OPERATOR(op_type, ::paddle::framework::OperatorWithKernel, \
##__VA_ARGS__)
#define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \ #define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \
REGISTER_OPERATOR(op_type, op_class, op_maker_class) REGISTER_OPERATOR(op_type, op_class, op_maker_class)
...@@ -223,6 +228,10 @@ class OpKernelRegistrar : public Registrar { ...@@ -223,6 +228,10 @@ class OpKernelRegistrar : public Registrar {
USE_OP_ITSELF(op_type); \ USE_OP_ITSELF(op_type); \
USE_OP_DEVICE_KERNEL(op_type, CPU); USE_OP_DEVICE_KERNEL(op_type, CPU);
#define USE_GPU_ONLY_OP(op_type) \
USE_OP_ITSELF(op_type); \
USE_OP_DEVICE_KERNEL(op_type, GPU)
#define USE_OP(op_type) \ #define USE_OP(op_type) \
USE_OP_ITSELF(op_type); \ USE_OP_ITSELF(op_type); \
USE_OP_KERNEL(op_type) USE_OP_KERNEL(op_type)
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/framework/operator.h" #include "paddle/framework/operator.h"
#include <algorithm> #include <algorithm>
#include <atomic> #include <atomic>
#include "paddle/framework/shape_inference.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -273,5 +274,137 @@ bool OpSupportGPU(const std::string& op_type) { ...@@ -273,5 +274,137 @@ bool OpSupportGPU(const std::string& op_type) {
return false; return false;
} }
class RuntimeInferShapeContext : public InferShapeContext {
public:
RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope)
: op_(op), scope_(scope) {}
bool HasInput(const std::string& name) const override {
auto& ins = Inputs(name);
size_t length = ins.size();
if (length == 0) {
return false;
}
PADDLE_ENFORCE_EQ(length, 1UL, "Input %s should have more than one inputs",
name);
auto ipt = ins[0];
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
return var != nullptr;
}
bool HasOutput(const std::string& name) const override {
auto& outs = Outputs(name);
size_t length = outs.size();
if (length == 0) {
return false;
}
PADDLE_ENFORCE_EQ(length, 1UL, "Output %s should have more than one inputs",
name);
auto ipt = outs[0];
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
return var != nullptr;
}
bool HasInputs(const std::string& name) const override {
auto inputs = op_.Inputs(name);
if (inputs.empty()) {
return false;
}
for (auto& input : inputs) {
if (scope_.FindVar(input) == nullptr) {
return false;
}
}
return true;
}
bool HasOutputs(const std::string& name) const override {
auto outputs = op_.Outputs(name);
if (outputs.empty()) {
return false;
}
for (auto& output : outputs) {
if (scope_.FindVar(output) == nullptr) {
return false;
}
}
return true;
}
DDim GetInputDim(const std::string& name) const override {
return GetDim(op_.Input(name));
}
void SetOutputDim(const std::string& name, const DDim& dim) override {
SetDim(op_.Output(name), dim);
}
AttrReader Attrs() const override { return AttrReader(op_.Attrs()); }
const std::vector<std::string>& Inputs(
const std::string& name) const override {
return op_.Inputs(name);
}
const std::vector<std::string>& Outputs(
const std::string& name) const override {
return op_.Outputs(name);
}
private:
DDim GetDim(const std::string& name) const override {
Variable* var = scope_.FindVar(name);
if (var->IsType<LoDTensor>()) {
return var->Get<LoDTensor>().dims();
} else if (var->IsType<SelectedRows>()) {
return var->Get<SelectedRows>().GetCompleteDims();
} else {
PADDLE_THROW("Variable type must be LoDTensor/SelectedRows.");
}
}
void SetDim(const std::string& name, const DDim& dim) override {
Variable* var = scope_.FindVar(name);
if (var->IsType<LoDTensor>()) {
var->GetMutable<LoDTensor>()->Resize(dim);
} else if (var->IsType<SelectedRows>()) {
var->GetMutable<SelectedRows>()->set_height(dim[0]);
} else {
PADDLE_THROW("Variable type must be LoDTensor/SelectedRows.");
}
}
const OperatorBase& op_;
const Scope& scope_;
};
void OperatorWithKernel::Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const {
VLOG(3) << "Running operator " << this->Type();
RuntimeInferShapeContext infer_shape_ctx(*this, scope);
this->InferShape(&infer_shape_ctx);
ExecutionContext ctx(*this, scope, dev_ctx);
// check if op[type] has kernel registered.
auto& all_op_kernels = AllOpKernels();
auto kernels_iter = all_op_kernels.find(type_);
if (kernels_iter == all_op_kernels.end()) {
PADDLE_THROW(
"There are no kernels which are registered in the %s operator.", type_);
}
// check if op[type] have kernel for kernel_key
OpKernelMap& kernels = kernels_iter->second;
auto kernel_key = OpKernelKey(IndicateDataType(ctx), dev_ctx);
auto kernel_iter = kernels.find(kernel_key);
if (kernel_iter == kernels.end()) {
PADDLE_THROW("The operator %s does not support %s", type_, kernel_key);
}
kernel_iter->second->Compute(ctx);
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -29,7 +29,6 @@ limitations under the License. */ ...@@ -29,7 +29,6 @@ limitations under the License. */
#include "paddle/framework/op_info.h" #include "paddle/framework/op_info.h"
#include "paddle/framework/scope.h" #include "paddle/framework/scope.h"
#include "paddle/framework/selected_rows.h" #include "paddle/framework/selected_rows.h"
#include "paddle/framework/shape_inference.h"
#include "paddle/framework/tensor.h" #include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h" #include "paddle/platform/device_context.h"
#include "paddle/platform/place.h" #include "paddle/platform/place.h"
...@@ -123,7 +122,7 @@ class OperatorBase { ...@@ -123,7 +122,7 @@ class OperatorBase {
protected: protected:
std::string type_; std::string type_;
// NOTE: in case of OpGrad, inputs_ contains: // NOTE: in case of OpGrad, inputs_ contains:
// I (Inputs)opear // I (Inputs)
// O (Outputs) // O (Outputs)
// OG (Output Gradients) // OG (Output Gradients)
VariableNameMap inputs_; VariableNameMap inputs_;
...@@ -288,6 +287,16 @@ class ExecutionContext { ...@@ -288,6 +287,16 @@ class ExecutionContext {
return device_context_; return device_context_;
} }
//! Get actual name vector for this input.
const std::vector<std::string>& Inputs(const std::string& name) const {
return op_.Inputs(name);
}
//! Get actual name vector for this output.
const std::vector<std::string>& Outputs(const std::string& name) const {
return op_.Outputs(name);
}
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
const platform::CUDADeviceContext& cuda_device_context() const { const platform::CUDADeviceContext& cuda_device_context() const {
PADDLE_ENFORCE(platform::is_gpu_place(device_context_.GetPlace())); PADDLE_ENFORCE(platform::is_gpu_place(device_context_.GetPlace()));
...@@ -317,226 +326,6 @@ template <> ...@@ -317,226 +326,6 @@ template <>
std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>( std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
const std::string& name) const; const std::string& name) const;
class CompileTimeInferShapeContext : public InferShapeContext {
public:
CompileTimeInferShapeContext(const OpDescBind& op, const BlockDescBind& block)
: op_(op), block_(block) {}
bool HasInput(const std::string& name) const override {
const std::vector<std::string>& input_names = op_.Input(name);
auto length = input_names.size();
if (length == 0) {
return false;
}
PADDLE_ENFORCE_EQ(length, 1UL,
"Input(%s) should have only one value, "
"but it have %d now",
name, length);
return block_.HasVarRecursive(input_names[0]);
}
bool HasOutput(const std::string& name) const override {
const std::vector<std::string>& output_names = op_.Output(name);
auto length = output_names.size();
if (length == 0) {
return false;
}
PADDLE_ENFORCE_EQ(length, 1UL,
"Output(%s) should have only one value, "
"but it have %d now",
name, length);
return block_.HasVarRecursive(output_names[0]);
}
bool HasInputs(const std::string& name) const override {
const std::vector<std::string>& input_names = op_.Input(name);
if (input_names.empty()) {
return false;
}
for (auto& input : input_names) {
if (!block_.HasVarRecursive(input)) return false;
}
return true;
}
bool HasOutputs(const std::string& name) const override {
const std::vector<std::string>& output_names = op_.Output(name);
if (output_names.empty()) {
return false;
}
for (auto& output : output_names) {
if (!block_.HasVarRecursive(output)) return false;
}
return true;
}
DDim GetInputDim(const std::string& name) const override {
std::vector<DDim> ddims = GetInputsDim(name);
auto length = ddims.size();
PADDLE_ENFORCE_EQ(length, 1UL,
"Input(%s) should have 1 value, "
"but it has %d now",
name, length);
return ddims[0];
}
void SetInputDim(const std::string& name, const DDim& dim) override {
SetInputsDim(name, {dim});
}
DDim GetOutputDim(const std::string& name) const override {
std::vector<DDim> ddims = GetOutputsDim(name);
auto length = ddims.size();
PADDLE_ENFORCE_EQ(length, 1UL,
"Output(%s) should have 1 value, "
"but it has %d now",
name, length);
return ddims[0];
}
void SetOutputDim(const std::string& name, const DDim& dim) override {
SetOutputsDim(name, {dim});
}
AttrReader Attrs() const override { return AttrReader(op_.GetAttrMap()); }
const std::vector<std::string>& Inputs(
const std::string& name) const override {
return op_.Input(name);
}
const std::vector<std::string>& Outputs(
const std::string& name) const override {
return op_.Output(name);
}
private:
DDim GetDim(const std::string& name) const override {
auto var = block_.FindVarRecursive(name);
PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name);
return framework::make_ddim(var->Shape());
}
void SetDim(const std::string& name, const DDim& dim) override {
block_.FindVarRecursive(name)->SetShape(framework::vectorize(dim));
}
const OpDescBind& op_;
const BlockDescBind& block_;
};
class RuntimeInferShapeContext : public InferShapeContext {
public:
RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope)
: op_(op), scope_(scope) {}
bool HasInput(const std::string& name) const override {
auto& ins = Inputs(name);
size_t length = ins.size();
if (length == 0) {
return false;
}
PADDLE_ENFORCE_EQ(length, 1UL, "Input %s should have more than one inputs",
name);
auto ipt = ins[0];
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
return var != nullptr;
}
bool HasOutput(const std::string& name) const override {
auto& outs = Outputs(name);
size_t length = outs.size();
if (length == 0) {
return false;
}
PADDLE_ENFORCE_EQ(length, 1UL, "Output %s should have more than one inputs",
name);
auto ipt = outs[0];
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
return var != nullptr;
}
bool HasInputs(const std::string& name) const override {
auto inputs = op_.Inputs(name);
if (inputs.empty()) {
return false;
}
for (auto& input : inputs) {
if (scope_.FindVar(input) == nullptr) {
return false;
}
}
return true;
}
bool HasOutputs(const std::string& name) const override {
auto outputs = op_.Outputs(name);
if (outputs.empty()) {
return false;
}
for (auto& output : outputs) {
if (scope_.FindVar(output) == nullptr) {
return false;
}
}
return true;
}
DDim GetInputDim(const std::string& name) const override {
return GetDim(op_.Input(name));
}
void SetInputDim(const std::string& name, const DDim& dim) override {
SetDim(op_.Input(name), dim);
}
DDim GetOutputDim(const std::string& name) const override {
return GetDim(op_.Output(name));
}
void SetOutputDim(const std::string& name, const DDim& dim) override {
SetDim(op_.Output(name), dim);
}
AttrReader Attrs() const override { return AttrReader(op_.Attrs()); }
const std::vector<std::string>& Inputs(
const std::string& name) const override {
return op_.Inputs(name);
}
const std::vector<std::string>& Outputs(
const std::string& name) const override {
return op_.Outputs(name);
}
private:
DDim GetDim(const std::string& name) const override {
Variable* var = scope_.FindVar(name);
if (var->IsType<LoDTensor>()) {
return var->Get<LoDTensor>().dims();
} else if (var->IsType<SelectedRows>()) {
return var->Get<SelectedRows>().GetCompleteDims();
} else {
PADDLE_THROW("Variable type must be LoDTensor/SelectedRows.");
}
}
void SetDim(const std::string& name, const DDim& dim) override {
Variable* var = scope_.FindVar(name);
if (var->IsType<LoDTensor>()) {
var->GetMutable<LoDTensor>()->Resize(dim);
} else if (var->IsType<SelectedRows>()) {
var->GetMutable<SelectedRows>()->set_height(dim[0]);
} else {
PADDLE_THROW("Variable type must be LoDTensor/SelectedRows.");
}
}
const OperatorBase& op_;
const Scope& scope_;
};
class OpKernelBase { class OpKernelBase {
public: public:
/** /**
...@@ -595,32 +384,7 @@ class OperatorWithKernel : public OperatorBase { ...@@ -595,32 +384,7 @@ class OperatorWithKernel : public OperatorBase {
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const Scope& scope, void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const final { const platform::DeviceContext& dev_ctx) const final;
VLOG(3) << "Running operator " << this->Type();
RuntimeInferShapeContext infer_shape_ctx(*this, scope);
this->InferShape(&infer_shape_ctx);
ExecutionContext ctx(*this, scope, dev_ctx);
// check if op[type] has kernel registered.
auto& all_op_kernels = AllOpKernels();
auto kernels_iter = all_op_kernels.find(type_);
if (kernels_iter == all_op_kernels.end()) {
PADDLE_THROW("op[%s] has no kernel", type_);
}
// check if op[type] have kernel for kernel_key
OpKernelMap& kernels = kernels_iter->second;
auto kernel_key = OpKernelKey(IndicateDataType(ctx), dev_ctx);
auto kernel_iter = kernels.find(kernel_key);
if (kernel_iter == kernels.end()) {
PADDLE_THROW("op[%s] has no kernel with kernel_key[%s]", type_,
kernel_key);
}
kernel_iter->second->Compute(ctx);
}
static std::unordered_map<std::string /* op_type */, OpKernelMap>& static std::unordered_map<std::string /* op_type */, OpKernelMap>&
AllOpKernels() { AllOpKernels() {
...@@ -644,6 +408,7 @@ class OperatorWithKernel : public OperatorBase { ...@@ -644,6 +408,7 @@ class OperatorWithKernel : public OperatorBase {
// indicate kernel DataType by input data. Defaultly all input data must be // indicate kernel DataType by input data. Defaultly all input data must be
// same. // same.
virtual DataType IndicateDataType(const ExecutionContext& ctx) const { virtual DataType IndicateDataType(const ExecutionContext& ctx) const {
VLOG(3) << "Default IndicateDataType " << this->Type();
auto& scope = ctx.scope(); auto& scope = ctx.scope();
int data_type = -1; int data_type = -1;
for (auto& input : this->inputs_) { for (auto& input : this->inputs_) {
......
...@@ -49,6 +49,13 @@ ProgramDescBind::ProgramDescBind(const ProgramDescBind &o) { ...@@ -49,6 +49,13 @@ ProgramDescBind::ProgramDescBind(const ProgramDescBind &o) {
} }
} }
ProgramDescBind::ProgramDescBind(const ProgramDesc &desc) {
desc_ = desc;
for (auto &block_desc : *desc_.mutable_blocks()) {
blocks_.emplace_back(new BlockDescBind(this, &block_desc));
}
}
ProgramDescBind::ProgramDescBind(const std::string &binary_str) { ProgramDescBind::ProgramDescBind(const std::string &binary_str) {
PADDLE_ENFORCE(desc_.ParseFromString(binary_str), PADDLE_ENFORCE(desc_.ParseFromString(binary_str),
"Fail to parse program_desc from binary string."); "Fail to parse program_desc from binary string.");
......
...@@ -29,6 +29,8 @@ class ProgramDescBind { ...@@ -29,6 +29,8 @@ class ProgramDescBind {
public: public:
ProgramDescBind(); ProgramDescBind();
explicit ProgramDescBind(const ProgramDesc &desc);
ProgramDescBind(const ProgramDescBind &o); ProgramDescBind(const ProgramDescBind &o);
explicit ProgramDescBind(const std::string &binary_str); explicit ProgramDescBind(const std::string &binary_str);
......
...@@ -46,7 +46,7 @@ bool IsTarget(const OpDesc& op_desc) { ...@@ -46,7 +46,7 @@ bool IsTarget(const OpDesc& op_desc) {
return false; return false;
} }
void prune_impl(const ProgramDesc& input, ProgramDesc& output, int block_id) { void prune_impl(const ProgramDesc& input, ProgramDesc* output, int block_id) {
// TODO(tonyyang-svail): // TODO(tonyyang-svail):
// - will change to use multiple blocks for RNN op and Cond Op // - will change to use multiple blocks for RNN op and Cond Op
...@@ -91,8 +91,8 @@ void prune_impl(const ProgramDesc& input, ProgramDesc& output, int block_id) { ...@@ -91,8 +91,8 @@ void prune_impl(const ProgramDesc& input, ProgramDesc& output, int block_id) {
// we reverse the should_run vector // we reverse the should_run vector
std::reverse(should_run.begin(), should_run.end()); std::reverse(should_run.begin(), should_run.end());
output = input; *output = input;
auto* op_field = output.mutable_blocks(block_id)->mutable_ops(); auto* op_field = output->mutable_blocks(block_id)->mutable_ops();
op_field->Clear(); op_field->Clear();
for (size_t i = 0; i < should_run.size(); ++i) { for (size_t i = 0; i < should_run.size(); ++i) {
if (should_run[i]) { if (should_run[i]) {
...@@ -101,7 +101,8 @@ void prune_impl(const ProgramDesc& input, ProgramDesc& output, int block_id) { ...@@ -101,7 +101,8 @@ void prune_impl(const ProgramDesc& input, ProgramDesc& output, int block_id) {
} }
} }
void Prune(const ProgramDesc& input, ProgramDesc& output) { // TODO(fengjiayi): Prune() could be inplaced to avoid unnecessary copies
void Prune(const ProgramDesc& input, ProgramDesc* output) {
prune_impl(input, output, 0); prune_impl(input, output, 0);
} }
......
...@@ -20,7 +20,7 @@ limitations under the License. */ ...@@ -20,7 +20,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
void Prune(const ProgramDesc& input, ProgramDesc& output); void Prune(const ProgramDesc& input, ProgramDesc* output);
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -59,11 +59,11 @@ TEST(Prune, one_operator) { ...@@ -59,11 +59,11 @@ TEST(Prune, one_operator) {
f::ProgramDesc *pdesc = program.Proto(); f::ProgramDesc *pdesc = program.Proto();
f::ProgramDesc pruned; f::ProgramDesc pruned;
Prune(*pdesc, pruned); Prune(*pdesc, &pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 0); PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 0);
pdesc->mutable_blocks(0)->mutable_ops(0)->set_is_target(true); pdesc->mutable_blocks(0)->mutable_ops(0)->set_is_target(true);
Prune(*pdesc, pruned); Prune(*pdesc, &pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 1); PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 1);
} }
...@@ -81,7 +81,7 @@ TEST(Prune, forward) { ...@@ -81,7 +81,7 @@ TEST(Prune, forward) {
for (int i = 0; i < pdesc->blocks(0).ops_size(); ++i) { for (int i = 0; i < pdesc->blocks(0).ops_size(); ++i) {
f::ProgramDesc pruned; f::ProgramDesc pruned;
pdesc->mutable_blocks(0)->mutable_ops(i)->set_is_target(true); pdesc->mutable_blocks(0)->mutable_ops(i)->set_is_target(true);
Prune(*pdesc, pruned); Prune(*pdesc, &pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), i + 1); PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), i + 1);
} }
} }
...@@ -100,7 +100,7 @@ TEST(Prune, multi_input_op) { ...@@ -100,7 +100,7 @@ TEST(Prune, multi_input_op) {
pdesc->mutable_blocks(0)->mutable_ops(3)->set_is_target(true); pdesc->mutable_blocks(0)->mutable_ops(3)->set_is_target(true);
f::ProgramDesc pruned; f::ProgramDesc pruned;
Prune(*pdesc, pruned); Prune(*pdesc, &pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 4); PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 4);
} }
...@@ -116,7 +116,7 @@ TEST(Prune, multi_output_op) { ...@@ -116,7 +116,7 @@ TEST(Prune, multi_output_op) {
pdesc->mutable_blocks(0)->mutable_ops(2)->set_is_target(true); pdesc->mutable_blocks(0)->mutable_ops(2)->set_is_target(true);
f::ProgramDesc pruned; f::ProgramDesc pruned;
Prune(*pdesc, pruned); Prune(*pdesc, &pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 2); PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 2);
} }
...@@ -133,6 +133,6 @@ TEST(Prune, multi_target) { ...@@ -133,6 +133,6 @@ TEST(Prune, multi_target) {
pdesc->mutable_blocks(0)->mutable_ops(2)->set_is_target(true); pdesc->mutable_blocks(0)->mutable_ops(2)->set_is_target(true);
f::ProgramDesc pruned; f::ProgramDesc pruned;
Prune(*pdesc, pruned); Prune(*pdesc, &pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 3); PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 3);
} }
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#include <memory> // for unique_ptr #include <memory> // for unique_ptr
#include <mutex> // for call_once #include <mutex> // for call_once
#include "glog/logging.h"
#include "paddle/string/printf.h" #include "paddle/string/printf.h"
namespace paddle { namespace paddle {
...@@ -23,7 +24,10 @@ namespace framework { ...@@ -23,7 +24,10 @@ namespace framework {
Scope::~Scope() { Scope::~Scope() {
DropKids(); DropKids();
for (auto& kv : vars_) delete kv.second; for (auto& kv : vars_) {
VLOG(3) << "Destroy variable " << kv.first;
delete kv.second;
}
} }
Scope& Scope::NewScope() const { Scope& Scope::NewScope() const {
...@@ -38,6 +42,7 @@ Variable* Scope::Var(const std::string& name) { ...@@ -38,6 +42,7 @@ Variable* Scope::Var(const std::string& name) {
} }
Variable* v = new Variable(); Variable* v = new Variable();
vars_[name] = v; vars_[name] = v;
VLOG(3) << "Create variable " << name << " on scope";
v->name_ = &(vars_.find(name)->first); v->name_ = &(vars_.find(name)->first);
return v; return v;
} }
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/shape_inference.h"
namespace paddle {
namespace framework {
std::vector<framework::DDim> InferShapeContext::GetInputsDim(
const std::string &name) const {
const std::vector<std::string> &names = Inputs(name);
return GetDims(names);
}
void InferShapeContext::SetOutputsDim(
const std::string &name, const std::vector<framework::DDim> &dims) {
auto &names = Outputs(name);
SetDims(names, dims);
}
void InferShapeContext::ShareLoD(const std::string &in, const std::string &out,
size_t i, size_t j) const {}
std::vector<framework::DDim> InferShapeContext::GetDims(
const std::vector<std::string> &names) const {
std::vector<framework::DDim> ret;
ret.reserve(names.size());
std::transform(
names.begin(), names.end(), std::back_inserter(ret),
[this](const std::string &name) { return this->GetDim(name); });
return ret;
}
void InferShapeContext::SetDims(const std::vector<std::string> &names,
const std::vector<framework::DDim> &dims) {
size_t length = names.size();
PADDLE_ENFORCE_EQ(length, dims.size());
for (size_t i = 0; i < length; ++i) {
SetDim(names[i], dims[i]);
}
}
} // namespace framework
} // namespace paddle
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/framework/attribute.h"
#include "paddle/framework/ddim.h" #include "paddle/framework/ddim.h"
namespace paddle { namespace paddle {
...@@ -21,7 +22,7 @@ namespace framework { ...@@ -21,7 +22,7 @@ namespace framework {
class InferShapeContext { class InferShapeContext {
public: public:
virtual ~InferShapeContext() {} virtual ~InferShapeContext() = default;
virtual bool HasInput(const std::string &name) const = 0; virtual bool HasInput(const std::string &name) const = 0;
virtual bool HasOutput(const std::string &name) const = 0; virtual bool HasOutput(const std::string &name) const = 0;
...@@ -29,57 +30,32 @@ class InferShapeContext { ...@@ -29,57 +30,32 @@ class InferShapeContext {
virtual bool HasOutputs(const std::string &name) const = 0; virtual bool HasOutputs(const std::string &name) const = 0;
virtual framework::DDim GetInputDim(const std::string &name) const = 0; virtual framework::DDim GetInputDim(const std::string &name) const = 0;
std::vector<framework::DDim> GetInputsDim(const std::string &name) const {
const std::vector<std::string> &names = Inputs(name); std::vector<framework::DDim> GetInputsDim(const std::string &name) const;
return GetDims(names);
}
virtual void SetInputDim(const std::string &name,
const framework::DDim &dim) = 0;
void SetInputsDim(const std::string &name,
const std::vector<framework::DDim> &dims) {
auto &names = Inputs(name);
SetDims(names, dims);
}
virtual framework::DDim GetOutputDim(const std::string &name) const = 0;
std::vector<framework::DDim> GetOutputsDim(const std::string &name) const {
const std::vector<std::string> &names = Outputs(name);
return GetDims(names);
}
virtual void SetOutputDim(const std::string &name, const DDim &dim) = 0; virtual void SetOutputDim(const std::string &name, const DDim &dim) = 0;
void SetOutputsDim(const std::string &name, void SetOutputsDim(const std::string &name,
const std::vector<framework::DDim> &dims) { const std::vector<framework::DDim> &dims);
auto &names = Outputs(name);
SetDims(names, dims);
}
virtual AttrReader Attrs() const = 0; virtual AttrReader Attrs() const = 0;
virtual const std::vector<std::string> &Inputs( virtual const std::vector<std::string> &Inputs(
const std::string &name) const = 0; const std::string &name) const = 0;
virtual const std::vector<std::string> &Outputs( virtual const std::vector<std::string> &Outputs(
const std::string &name) const = 0; const std::string &name) const = 0;
// TODO(qiao) implement this function // TODO(qiao) implement this function
void ShareLoD(const std::string &in, const std::string &out, size_t i = 0, void ShareLoD(const std::string &in, const std::string &out, size_t i = 0,
size_t j = 0) const {} size_t j = 0) const;
protected: protected:
virtual framework::DDim GetDim(const std::string &name) const = 0; virtual framework::DDim GetDim(const std::string &name) const = 0;
virtual void SetDim(const std::string &name, const framework::DDim &dim) = 0; virtual void SetDim(const std::string &name, const framework::DDim &dim) = 0;
std::vector<framework::DDim> GetDims( std::vector<framework::DDim> GetDims(
const std::vector<std::string> &names) const { const std::vector<std::string> &names) const;
std::vector<framework::DDim> ret;
ret.reserve(names.size());
std::transform(
names.begin(), names.end(), std::back_inserter(ret),
[this](const std::string &name) { return this->GetDim(name); });
return ret;
}
void SetDims(const std::vector<std::string> &names, void SetDims(const std::vector<std::string> &names,
const std::vector<framework::DDim> &dims) { const std::vector<framework::DDim> &dims);
size_t length = names.size();
PADDLE_ENFORCE_EQ(length, dims.size());
for (size_t i = 0; i < length; ++i) {
SetDim(names[i], dims[i]);
}
}
}; };
} // namespace framework } // namespace framework
......
...@@ -126,11 +126,16 @@ class Tensor { ...@@ -126,11 +126,16 @@ class Tensor {
inline Tensor Slice(const int& begin_idx, const int& end_idx) const; inline Tensor Slice(const int& begin_idx, const int& end_idx) const;
platform::Place place() const { platform::Place place() const {
PADDLE_ENFORCE_NOT_NULL(holder_, "Tensor get place() must contains holder"); PADDLE_ENFORCE_NOT_NULL(
holder_, "Tensor not initialized yet when Tensor::place() is called.");
return holder_->place(); return holder_->place();
} }
std::type_index type() const { return holder_->type(); } std::type_index type() const {
PADDLE_ENFORCE_NOT_NULL(
holder_, "Tensor not initialized yet when Tensor::type() is called.");
return holder_->type();
}
size_t memory_size() const; size_t memory_size() const;
......
add_subdirectory(detail) add_subdirectory(detail)
cc_library(memory SRCS memory.cc) cc_library(memory SRCS memory.cc DEPS place)
cc_library(memcpy SRCS memcpy.cc) cc_library(memcpy SRCS memcpy.cc)
cc_library(paddle_memory cc_library(paddle_memory
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
limitations under the License. */ limitations under the License. */
#include "paddle/memory/detail/meta_cache.h" #include "paddle/memory/detail/meta_cache.h"
#include "glog/logging.h"
#include "paddle/memory/detail/memory_block.h" #include "paddle/memory/detail/memory_block.h"
#include "paddle/platform/assert.h" #include "paddle/platform/assert.h"
...@@ -28,7 +29,9 @@ Metadata MetadataCache::load(const MemoryBlock* block) { ...@@ -28,7 +29,9 @@ Metadata MetadataCache::load(const MemoryBlock* block) {
PADDLE_ASSERT(existing_metadata->second.check_guards()); PADDLE_ASSERT(existing_metadata->second.check_guards());
return existing_metadata->second; return existing_metadata->second;
} else { } else {
PADDLE_ASSERT(reinterpret_cast<const Metadata*>(block)->check_guards()); auto* meta = reinterpret_cast<const Metadata*>(block);
VLOG(3) << "Load MetaData type=" << meta->type;
PADDLE_ASSERT(meta->check_guards());
return *reinterpret_cast<const Metadata*>(block); return *reinterpret_cast<const Metadata*>(block);
} }
} }
......
...@@ -39,11 +39,15 @@ BuddyAllocator* GetCPUBuddyAllocator() { ...@@ -39,11 +39,15 @@ BuddyAllocator* GetCPUBuddyAllocator() {
template <> template <>
void* Alloc<platform::CPUPlace>(platform::CPUPlace place, size_t size) { void* Alloc<platform::CPUPlace>(platform::CPUPlace place, size_t size) {
return GetCPUBuddyAllocator()->Alloc(size); VLOG(3) << "Allocate " << size << " bytes on " << platform::Place(place);
void* p = GetCPUBuddyAllocator()->Alloc(size);
VLOG(3) << " pointer=" << p;
return p;
} }
template <> template <>
void Free<platform::CPUPlace>(platform::CPUPlace place, void* p) { void Free<platform::CPUPlace>(platform::CPUPlace place, void* p) {
VLOG(3) << "Free pointer=" << p << " on " << platform::Place(place);
GetCPUBuddyAllocator()->Free(p); GetCPUBuddyAllocator()->Free(p);
} }
......
...@@ -97,6 +97,13 @@ function(op_library TARGET) ...@@ -97,6 +97,13 @@ function(op_library TARGET)
file(APPEND ${pybind_file} "USE_OP(sigmoid);\n") file(APPEND ${pybind_file} "USE_OP(sigmoid);\n")
endif() endif()
# nccl_op contains several operators
if ("${TARGET}" STREQUAL "nccl_op")
set(pybind_flag 1)
# It's enough to just adding one operator to pybind
file(APPEND ${pybind_file} "USE_GPU_ONLY_OP(ncclAllReduce);\n")
endif()
# reduce_op contains several operators # reduce_op contains several operators
if ("${TARGET}" STREQUAL "reduce_op") if ("${TARGET}" STREQUAL "reduce_op")
set(pybind_flag 1) set(pybind_flag 1)
...@@ -128,6 +135,7 @@ function(op_library TARGET) ...@@ -128,6 +135,7 @@ function(op_library TARGET)
endfunction() endfunction()
add_subdirectory(math) add_subdirectory(math)
add_subdirectory(nccl)
set(DEPS_OPS set(DEPS_OPS
recurrent_op recurrent_op
...@@ -138,6 +146,7 @@ set(DEPS_OPS ...@@ -138,6 +146,7 @@ set(DEPS_OPS
pool_op pool_op
pool_with_index_op pool_with_index_op
conv_op conv_op
nccl_op
sequence_conv_op sequence_conv_op
lstm_op) lstm_op)
...@@ -151,6 +160,9 @@ op_library(conv_op DEPS vol2col) ...@@ -151,6 +160,9 @@ op_library(conv_op DEPS vol2col)
op_library(sum_op DEPS net_op selected_rows_functor) op_library(sum_op DEPS net_op selected_rows_functor)
op_library(pool_op DEPS pooling) op_library(pool_op DEPS pooling)
op_library(pool_with_index_op DEPS pooling) op_library(pool_with_index_op DEPS pooling)
if(WITH_GPU)
op_library(nccl_op DEPS nccl_common)
endif()
op_library(sequence_conv_op DEPS context_project) op_library(sequence_conv_op DEPS context_project)
op_library(lstm_op DEPS sequence2batch lstm_compute) op_library(lstm_op DEPS sequence2batch lstm_compute)
...@@ -166,4 +178,8 @@ cc_test(net_op_test SRCS net_op_test.cc DEPS net_op) ...@@ -166,4 +178,8 @@ cc_test(net_op_test SRCS net_op_test.cc DEPS net_op)
cc_test(scatter_test SRCS scatter_test.cc DEPS tensor) cc_test(scatter_test SRCS scatter_test.cc DEPS tensor)
cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor paddle_memory) cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor paddle_memory)
cc_test(dynamic_recurrent_op_test SRCS dynamic_recurrent_op_test.cc DEPS dynamic_recurrent_op recurrent_op tensor_array) cc_test(dynamic_recurrent_op_test SRCS dynamic_recurrent_op_test.cc DEPS dynamic_recurrent_op recurrent_op tensor_array)
if(WITH_GPU)
nv_test(nccl_op_test SRCS nccl_op_test.cu DEPS nccl_op gpu_info device_context)
endif()
cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op) cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op)
...@@ -70,7 +70,5 @@ information, or not. But the output only shares the LoD with input `Inference`. ...@@ -70,7 +70,5 @@ information, or not. But the output only shares the LoD with input `Inference`.
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(accuracy, ops::AccuracyOp, ops::AccuracyOpMaker); REGISTER_OP_WITHOUT_GRADIENT(accuracy, ops::AccuracyOp, ops::AccuracyOpMaker);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
accuracy, ops::AccuracyKernel<paddle::platform::CPUPlace, float>, accuracy, ops::AccuracyKernel<paddle::platform::CPUPlace, int>,
ops::AccuracyKernel<paddle::platform::CPUPlace, int>,
ops::AccuracyKernel<paddle::platform::CPUPlace, double>,
ops::AccuracyKernel<paddle::platform::CPUPlace, int64_t>); ops::AccuracyKernel<paddle::platform::CPUPlace, int64_t>);
...@@ -81,7 +81,5 @@ class AccuracyOpCUDAKernel : public framework::OpKernel<T> { ...@@ -81,7 +81,5 @@ class AccuracyOpCUDAKernel : public framework::OpKernel<T> {
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
REGISTER_OP_GPU_KERNEL(accuracy, paddle::operators::AccuracyOpCUDAKernel<float>, REGISTER_OP_GPU_KERNEL(accuracy, paddle::operators::AccuracyOpCUDAKernel<int>,
paddle::operators::AccuracyOpCUDAKernel<double>,
paddle::operators::AccuracyOpCUDAKernel<int>,
paddle::operators::AccuracyOpCUDAKernel<int64_t>); paddle::operators::AccuracyOpCUDAKernel<int64_t>);
...@@ -18,6 +18,7 @@ namespace paddle { ...@@ -18,6 +18,7 @@ namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
template <typename T, int MajorType = Eigen::RowMajor, template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex> typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
...@@ -64,6 +65,9 @@ class BatchNormOp : public framework::OperatorWithKernel { ...@@ -64,6 +65,9 @@ class BatchNormOp : public framework::OperatorWithKernel {
(tensor_format == TensorFormat::NCHW ? x_dims[1] (tensor_format == TensorFormat::NCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]); : x_dims[x_dims.size() - 1]);
PADDLE_ENFORCE(x_dims.size() >= 3 && x_dims.size() <= 5,
"Input x must have 3 to 5 dimensions.");
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale").size(), 1UL); PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale").size(), 1UL);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], C); PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], C);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias").size(), 1UL); PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias").size(), 1UL);
...@@ -108,10 +112,12 @@ class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -108,10 +112,12 @@ class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
"Store the global Variance when training"); "Store the global Variance when training");
AddOutput("SavedMean", AddOutput("SavedMean",
"Mean of the current mini batch, " "Mean of the current mini batch, "
"will apply to output when training"); "will apply to output when training")
.AsIntermediate();
AddOutput("SavedVariance", AddOutput("SavedVariance",
"Variance of the current mini batch, " "Variance of the current mini batch, "
"will apply to output when training"); "will apply to output when training")
.AsIntermediate();
AddComment(R"DOC( AddComment(R"DOC(
https://arxiv.org/pdf/1502.03167.pdf https://arxiv.org/pdf/1502.03167.pdf
...@@ -135,7 +141,6 @@ class BatchNormKernel<platform::CPUPlace, T> : public framework::OpKernel<T> { ...@@ -135,7 +141,6 @@ class BatchNormKernel<platform::CPUPlace, T> : public framework::OpKernel<T> {
const auto *x = ctx.Input<Tensor>("X"); const auto *x = ctx.Input<Tensor>("X");
const auto &x_dims = x->dims(); const auto &x_dims = x->dims();
PADDLE_ENFORCE(x_dims.size() >= 3 && x_dims.size() <= 5, PADDLE_ENFORCE(x_dims.size() >= 3 && x_dims.size() <= 5,
"The Input dim size should be between 3 and 5"); "The Input dim size should be between 3 and 5");
const int N = x_dims[0]; const int N = x_dims[0];
...@@ -289,6 +294,25 @@ class BatchNormGradOp : public framework::OperatorWithKernel { ...@@ -289,6 +294,25 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
ctx->SetOutputDim(framework::GradVarName("Scale"), {C}); ctx->SetOutputDim(framework::GradVarName("Scale"), {C});
ctx->SetOutputDim(framework::GradVarName("Bias"), {C}); ctx->SetOutputDim(framework::GradVarName("Bias"), {C});
} }
framework::DataType IndicateDataType(
const framework::ExecutionContext &ctx) const override {
VLOG(3) << "IndicateDataType " << this->Type();
const auto *var = ctx.InputVar(framework::GradVarName("Y"));
if (var == nullptr) {
PADDLE_THROW("can't find Y@GRAD");
}
const Tensor *t = nullptr;
if (var->IsType<Tensor>()) {
t = &var->Get<Tensor>();
} else if (var->IsType<LoDTensor>()) {
t = &var->Get<LoDTensor>();
}
if (t == nullptr) {
PADDLE_THROW("can't find Y@GRAD");
}
return framework::ToDataType(t->type());
}
}; };
template <typename T> template <typename T>
......
...@@ -117,9 +117,6 @@ class BatchNormKernel<platform::GPUPlace, T> : public framework::OpKernel<T> { ...@@ -117,9 +117,6 @@ class BatchNormKernel<platform::GPUPlace, T> : public framework::OpKernel<T> {
math::SetConstant<platform::GPUPlace, T> functor; math::SetConstant<platform::GPUPlace, T> functor;
functor(ctx.device_context(), saved_mean, 0); functor(ctx.device_context(), saved_mean, 0);
functor(ctx.device_context(), saved_variance, 0); functor(ctx.device_context(), saved_variance, 0);
// FIXME(qiao) should not set zero self
functor(ctx.device_context(), mean_out, 0);
functor(ctx.device_context(), variance_out, 0);
auto handle = ctx.cuda_device_context().cudnn_handle(); auto handle = ctx.cuda_device_context().cudnn_handle();
...@@ -211,8 +208,15 @@ class BatchNormGradKernel<platform::GPUPlace, T> ...@@ -211,8 +208,15 @@ class BatchNormGradKernel<platform::GPUPlace, T>
mode_ = CUDNN_BATCHNORM_SPATIAL; mode_ = CUDNN_BATCHNORM_SPATIAL;
#endif #endif
std::vector<int> dims = {N, C, H, W, D}; std::vector<int> dims;
std::vector<int> strides = {H * W * C * D, 1, W * D * C, D * C, C}; std::vector<int> strides;
if (tensor_format == TensorFormat::NCHW) {
dims = {N, C, H, W, D};
strides = {C * H * W * D, H * W * D, W * D, D, 1};
} else {
dims = {N, C, H, W, D};
strides = {H * W * C * D, 1, W * D * C, D * C, C};
}
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor( CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
data_desc_, CudnnDataType<T>::type, data_desc_, CudnnDataType<T>::type,
x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data())); x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/cast_op.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
class CastOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
CastOpProtoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "the input tensor of cast op");
AddOutput("Out", "the output tensor of cast op");
AddComment(R"DOC(Cast operator.
cast the input tensor to other data type.
)DOC");
AddAttr<int>("out_data_type", "output data type");
AddAttr<int>("in_data_type", "input data type");
}
};
class CastOpInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE(context->HasInput("X"), "The input of cast op must be set");
PADDLE_ENFORCE(context->HasOutput("Out"),
"The output of cast op must be set");
context->SetOutputDim("Out", context->GetInputDim("X"));
context->ShareLoD("X", "Out");
}
};
class CastOpGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDescBind> Apply() const override {
auto grad = new framework::OpDescBind();
grad->SetType("cast");
grad->SetInput("X", OutputGrad("Out"));
grad->SetOutput("Out", InputGrad("X"));
grad->SetAttr("out_data_type", GetAttr("in_data_type"));
grad->SetAttr("in_data_type", GetAttr("out_data_type"));
return std::unique_ptr<framework::OpDescBind>(grad);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
using CPU = paddle::platform::CPUPlace;
REGISTER_OP_WITH_KERNEL(cast, ops::CastOpGradMaker, ops::CastOpInferShape,
ops::CastOpProtoMaker);
REGISTER_OP_CPU_KERNEL(cast, ops::CastOpKernel<CPU, float>,
ops::CastOpKernel<CPU, double>,
ops::CastOpKernel<CPU, int>,
ops::CastOpKernel<CPU, int64_t>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/cast_op.h"
template <typename T>
using CastOpKernel =
paddle::operators::CastOpKernel<paddle::platform::GPUPlace, T>;
REGISTER_OP_GPU_KERNEL(cast, CastOpKernel<float>, CastOpKernel<double>,
CastOpKernel<int>, CastOpKernel<int64_t>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/data_type.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/op_registry.h"
#include "paddle/platform/transform.h"
namespace paddle {
namespace operators {
template <typename InT, typename OutT>
struct CastOpTransformFunctor {
HOSTDEVICE OutT operator()(InT in) const { return static_cast<OutT>(in); }
};
template <typename Place, typename InT>
struct CastOpFunctor {
const framework::Tensor* in_;
framework::Tensor* out_;
const platform::DeviceContext& ctx_;
CastOpFunctor(const framework::Tensor* in, framework::Tensor* out,
const platform::DeviceContext& ctx)
: in_(in), out_(out), ctx_(ctx) {}
template <typename OutT>
void operator()() const {
auto* in_begin = in_->data<InT>();
auto numel = in_->numel();
auto* in_end = in_begin + numel;
auto* out_begin = out_->mutable_data<OutT>(ctx_.GetPlace());
platform::Transform<Place> trans;
trans(ctx_, in_begin, in_end, out_begin,
CastOpTransformFunctor<InT, OutT>());
}
};
template <typename Place, typename InT>
class CastOpKernel : public framework::OpKernel<InT> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out");
framework::VisitDataType(
static_cast<framework::DataType>(context.Attr<int>("out_data_type")),
CastOpFunctor<Place, InT>(in, out, context.device_context()));
}
};
} // namespace operators
} // namespace paddle
...@@ -21,7 +21,7 @@ namespace { ...@@ -21,7 +21,7 @@ namespace {
template <typename T> template <typename T>
__global__ void CrossEntropyGradientKernel(T* dX, const T* dY, const T* X, __global__ void CrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
const int* label, const int N, const int64_t* label, const int N,
const int D) { const int D) {
// TOOD(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file. // TOOD(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file.
// CUDA_1D_KERNEL_LOOP(i, N) { // CUDA_1D_KERNEL_LOOP(i, N) {
...@@ -77,8 +77,8 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel<T> { ...@@ -77,8 +77,8 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel<T> {
T* dx_data = dx->mutable_data<T>(ctx.GetPlace()); T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
const T* x_data = x->data<T>(); const T* x_data = x->data<T>();
int batch_size = x->dims()[0]; int64_t batch_size = x->dims()[0];
int class_num = x->dims()[1]; int64_t class_num = x->dims()[1];
int block = 512; int block = 512;
int grid = (batch_size * class_num + block - 1) / block; int grid = (batch_size * class_num + block - 1) / block;
...@@ -93,7 +93,7 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel<T> { ...@@ -93,7 +93,7 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel<T> {
} else { } else {
math::SetConstant<platform::GPUPlace, T> functor; math::SetConstant<platform::GPUPlace, T> functor;
functor(ctx.device_context(), dx, 0); functor(ctx.device_context(), dx, 0);
auto* label_data = label->data<int>(); auto* label_data = label->data<int64_t>();
grid = (batch_size + block - 1) / block; grid = (batch_size + block - 1) / block;
CrossEntropyGradientKernel<T><<< CrossEntropyGradientKernel<T><<<
grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>( grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
......
...@@ -54,7 +54,7 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel<T> { ...@@ -54,7 +54,7 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel<T> {
Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X")); Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
T* dx_data = dx->mutable_data<T>(ctx.GetPlace()); T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
int class_num = x->dims()[1]; int64_t class_num = x->dims()[1];
if (ctx.Attr<bool>("soft_label")) { if (ctx.Attr<bool>("soft_label")) {
auto x_mat = EigenMatrix<T>::From(*x); auto x_mat = EigenMatrix<T>::From(*x);
auto dy_mat = EigenMatrix<T>::From(*dy); auto dy_mat = EigenMatrix<T>::From(*dy);
...@@ -62,20 +62,20 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel<T> { ...@@ -62,20 +62,20 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel<T> {
auto dx_mat = EigenMatrix<T>::From(*dx); auto dx_mat = EigenMatrix<T>::From(*dx);
dx_mat.device(ctx.GetEigenDevice<platform::CPUPlace>()) = dx_mat.device(ctx.GetEigenDevice<platform::CPUPlace>()) =
-(lbl_mat * dy_mat.broadcast(Eigen::DSizes<int, 2>(1, class_num)) / -(lbl_mat *
x_mat); dy_mat.broadcast(Eigen::DSizes<int64_t, 2>(1, class_num)) / x_mat);
} else { } else {
int batch_size = x->dims()[0]; int64_t batch_size = x->dims()[0];
const T* dy_data = dy->data<T>(); const T* dy_data = dy->data<T>();
const T* x_data = x->data<T>(); const T* x_data = x->data<T>();
const int* label_data = label->data<int>(); const int64_t* label_data = label->data<int64_t>();
math::SetConstant<platform::CPUPlace, T> functor; math::SetConstant<platform::CPUPlace, T> functor;
functor(ctx.device_context(), dx, 0); functor(ctx.device_context(), dx, 0);
for (int i = 0; i < batch_size; ++i) { for (int64_t i = 0; i < batch_size; ++i) {
PADDLE_ASSERT(label_data[i] >= 0 || label_data[i] < class_num); PADDLE_ASSERT(label_data[i] >= 0 || label_data[i] < class_num);
int index = i * class_num + label_data[i]; int64_t index = i * class_num + label_data[i];
dx_data[index] = -dy_data[i] / x_data[index]; dx_data[index] = -dy_data[i] / x_data[index];
} }
} }
......
...@@ -41,7 +41,7 @@ class FeedOp : public framework::OperatorBase { ...@@ -41,7 +41,7 @@ class FeedOp : public framework::OperatorBase {
auto col = Attr<int>("col"); auto col = Attr<int>("col");
VLOG(3) << "Feed Var " << feed_var_name << "'s " << col << " column to var" VLOG(3) << "Feed Var " << feed_var_name << "'s " << col << " column to var "
<< out_name; << out_name;
auto &feed_list = feed_var->Get<framework::FeedFetchList>(); auto &feed_list = feed_var->Get<framework::FeedFetchList>();
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
limitations under the License. */ limitations under the License. */
#include "paddle/operators/lookup_table_op.h" #include "paddle/operators/lookup_table_op.h"
#include "paddle/framework/var_type_inference.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -60,6 +61,7 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -60,6 +61,7 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
"Ids must be a column vector with rank = 2." "Ids must be a column vector with rank = 2."
"The 2nd dimension size must be 1"); "The 2nd dimension size must be 1");
AddOutput("Out", "The lookup results, which have the same type with W."); AddOutput("Out", "The lookup results, which have the same type with W.");
AddAttr<bool>("is_sparse", "Sparse update").SetDefault(false);
AddComment(R"DOC( AddComment(R"DOC(
This operator is used to perform lookups on the parameter W, This operator is used to perform lookups on the parameter W,
then concatenated into a dense tensor. then concatenated into a dense tensor.
...@@ -70,6 +72,15 @@ or not. And the output only shares the LoD with input `Ids`. ...@@ -70,6 +72,15 @@ or not. And the output only shares the LoD with input `Ids`.
} }
}; };
class LookupTableOpGradDescMaker
: public framework::DefaultGradOpDescMaker<true> {
using ::paddle::framework::DefaultGradOpDescMaker<
true>::DefaultGradOpDescMaker;
protected:
virtual std::string GradOpType() const { return "lookup_table_grad"; }
};
class LookupTableOpGrad : public framework::OperatorWithKernel { class LookupTableOpGrad : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -86,12 +97,35 @@ class LookupTableOpGrad : public framework::OperatorWithKernel { ...@@ -86,12 +97,35 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
} }
}; };
class LookupTableOpGradVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDescBind& op_desc,
framework::BlockDescBind* block) const override {
auto out_var_name = op_desc.Output(framework::GradVarName("W")).front();
auto attr = op_desc.GetAttr("is_sparse");
bool is_sparse = boost::get<bool>(attr);
if (is_sparse) {
VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W")
<< " is set to SelectedRows";
block->Var(out_var_name)->SetType(framework::VarDesc::SELECTED_ROWS);
} else {
VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W")
<< " is set to LoDTensor";
block->Var(out_var_name)->SetType(framework::VarDesc::LOD_TENSOR);
}
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(lookup_table, ops::LookupTableOp, ops::LookupTableOpMaker, REGISTER_OPERATOR(lookup_table, ops::LookupTableOp,
lookup_table_grad, ops::LookupTableOpGrad); ops::LookupTableOpGradDescMaker, ops::LookupTableOpMaker);
REGISTER_OPERATOR(lookup_table_grad, ops::LookupTableOpGrad,
REGISTER_OP_CPU_KERNEL(lookup_table, ops::LookupTableKernel<float>); ops::LookupTableOpGradVarTypeInference);
REGISTER_OP_CPU_KERNEL(lookup_table_grad, ops::LookupTableGradKernel<float>);
REGISTER_OP_CPU_KERNEL(lookup_table, ops::LookupTableKernel<float>,
ops::LookupTableKernel<double>);
REGISTER_OP_CPU_KERNEL(lookup_table_grad, ops::LookupTableGradKernel<float>,
ops::LookupTableGradKernel<double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -14,22 +11,21 @@ ...@@ -14,22 +11,21 @@
#include "paddle/framework/eigen.h" #include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/operators/lookup_table_op.h"
#include "paddle/platform/assert.h" #include "paddle/platform/assert.h"
#include "paddle/platform/cuda_helper.h" #include "paddle/platform/cuda_helper.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
template <typename T, int BlockDimX, int BlockDimY, int GridDimX> template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
__global__ void LookupTable(T* output, const T* table, const int32_t* ids, __global__ void LookupTable(T* output, const T* table, const int64_t* ids,
const int N, const int K, const int D) { const int64_t N, const int64_t K, const int64_t D) {
int idx = threadIdx.x; int idx = threadIdx.x;
int idy = blockIdx.x + threadIdx.y * GridDimX; int idy = blockIdx.x + threadIdx.y * GridDimX;
while (idy < K) { while (idy < K) {
int id = ids[idy]; int64_t id = ids[idy];
PADDLE_ASSERT(id >= 0); PADDLE_ASSERT(id >= 0);
PADDLE_ASSERT(id < N); PADDLE_ASSERT(id < N);
T* out = output + idy * D; T* out = output + idy * D;
...@@ -42,8 +38,9 @@ __global__ void LookupTable(T* output, const T* table, const int32_t* ids, ...@@ -42,8 +38,9 @@ __global__ void LookupTable(T* output, const T* table, const int32_t* ids,
} }
template <typename T, int BlockDimX, int BlockDimY, int GridDimX> template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
__global__ void LookupTableGrad(T* table, const T* output, const int32_t* ids, __global__ void LookupTableGrad(T* table, const T* output, const int64_t* ids,
const int N, const int K, const int D) { const int64_t N, const int64_t K,
const int64_t D) {
int idx = threadIdx.x; int idx = threadIdx.x;
int idy = blockIdx.x + threadIdx.y * GridDimX; int idy = blockIdx.x + threadIdx.y * GridDimX;
...@@ -71,7 +68,7 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> { ...@@ -71,7 +68,7 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
size_t N = table_t->dims()[0]; size_t N = table_t->dims()[0];
size_t D = table_t->dims()[1]; size_t D = table_t->dims()[1];
size_t K = ids_t->numel(); size_t K = ids_t->numel();
auto ids = ids_t->data<int32_t>(); auto ids = ids_t->data<int64_t>();
auto table = table_t->data<T>(); auto table = table_t->data<T>();
auto output = output_t->mutable_data<T>(context.GetPlace()); auto output = output_t->mutable_data<T>(context.GetPlace());
...@@ -88,27 +85,63 @@ template <typename T> ...@@ -88,27 +85,63 @@ template <typename T>
class LookupTableGradCUDAKernel : public framework::OpKernel<T> { class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto ids_t = context.Input<Tensor>("Ids"); bool is_sparse = context.Attr<bool>("is_sparse");
auto d_output_t = context.Input<Tensor>(framework::GradVarName("Out")); if (is_sparse) {
auto d_table_t = context.Output<Tensor>(framework::GradVarName("W")); auto* ids = context.Input<Tensor>("Ids");
auto* table = context.Input<Tensor>("W");
int N = d_table_t->dims()[0]; auto* d_output = context.Input<Tensor>(framework::GradVarName("Out"));
int D = d_table_t->dims()[1]; auto* d_table = context.Output<SelectedRows>(framework::GradVarName("W"));
int K = ids_t->numel();
const int32_t* ids = ids_t->data<int32_t>(); auto* ids_data = ids->data<int64_t>();
const T* d_output = d_output_t->data<T>(); auto ids_dim = ids->dims();
T* d_table = d_table_t->mutable_data<T>(context.GetPlace());
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
auto t = framework::EigenVector<T>::Flatten(*d_table_t); context.device_context())
t.device(context.GetEigenDevice<platform::GPUPlace>()) = .stream();
t.constant(static_cast<T>(0)); // copy GPU memory to CPU pinned memory
framework::Vector<int64_t> new_rows;
dim3 threads(128, 8); new_rows.resize(ids_dim[0]);
dim3 grids(8, 1); auto gpu_place = boost::get<platform::GPUPlace>(context.GetPlace());
LookupTableGrad<T, 128, 8, 8><<<
grids, threads, 0, reinterpret_cast<const platform::CUDADeviceContext&>( memory::Copy(platform::CPUPlace(), new_rows.data(), gpu_place, ids_data,
ids_dim[0] * sizeof(int64_t), stream);
d_table->set_rows(new_rows);
auto* d_table_value = d_table->mutable_value();
d_table_value->Resize({ids_dim[0], table->dims()[1]});
d_table_value->mutable_data<T>(context.GetPlace());
auto* d_table_data = d_table_value->data<T>();
auto* d_output_data = d_output->data<T>();
PADDLE_ENFORCE_EQ(d_table_value->dims(), d_output->dims());
memory::Copy(gpu_place, d_table_data, gpu_place, d_output_data,
d_output->numel(), stream);
} else {
auto ids_t = context.Input<Tensor>("Ids");
auto d_output_t = context.Input<Tensor>(framework::GradVarName("Out"));
auto d_table_t = context.Output<Tensor>(framework::GradVarName("W"));
int N = d_table_t->dims()[0];
int D = d_table_t->dims()[1];
int K = ids_t->numel();
const int64_t* ids = ids_t->data<int64_t>();
const T* d_output = d_output_t->data<T>();
T* d_table = d_table_t->mutable_data<T>(context.GetPlace());
auto t = framework::EigenVector<T>::Flatten(*d_table_t);
t.device(context.GetEigenDevice<platform::GPUPlace>()) =
t.constant(static_cast<T>(0));
dim3 threads(128, 8);
dim3 grids(8, 1);
LookupTableGrad<T, 128, 8,
8><<<grids, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(
context.device_context()) context.device_context())
.stream()>>>(d_table, d_output, ids, N, K, D); .stream()>>>(d_table, d_output, ids, N, K, D);
}
} }
}; };
...@@ -116,6 +149,7 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> { ...@@ -116,6 +149,7 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(lookup_table, ops::LookupTableCUDAKernel<float>); REGISTER_OP_GPU_KERNEL(lookup_table, ops::LookupTableCUDAKernel<float>,
REGISTER_OP_GPU_KERNEL(lookup_table_grad, ops::LookupTableCUDAKernel<double>);
ops::LookupTableGradCUDAKernel<float>); REGISTER_OP_GPU_KERNEL(lookup_table_grad, ops::LookupTableGradCUDAKernel<float>,
ops::LookupTableGradCUDAKernel<double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -15,12 +12,15 @@ ...@@ -15,12 +12,15 @@
#pragma once #pragma once
#include "paddle/framework/eigen.h" #include "paddle/framework/eigen.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/framework/selected_rows.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
using SelectedRows = framework::SelectedRows;
template <typename T> template <typename T>
class LookupTableKernel : public framework::OpKernel<T> { class LookupTableKernel : public framework::OpKernel<T> {
...@@ -32,7 +32,7 @@ class LookupTableKernel : public framework::OpKernel<T> { ...@@ -32,7 +32,7 @@ class LookupTableKernel : public framework::OpKernel<T> {
int N = table_t->dims()[0]; int N = table_t->dims()[0];
int D = table_t->dims()[1]; int D = table_t->dims()[1];
auto ids = ids_t->data<int32_t>(); auto ids = ids_t->data<int64_t>();
auto table = table_t->data<T>(); auto table = table_t->data<T>();
auto output = output_t->mutable_data<T>(context.GetPlace()); auto output = output_t->mutable_data<T>(context.GetPlace());
for (int64_t i = 0; i < ids_t->numel(); ++i) { for (int64_t i = 0; i < ids_t->numel(); ++i) {
...@@ -47,25 +47,55 @@ template <typename T> ...@@ -47,25 +47,55 @@ template <typename T>
class LookupTableGradKernel : public framework::OpKernel<T> { class LookupTableGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto ids_t = context.Input<Tensor>("Ids"); bool is_sparse = context.Attr<bool>("is_sparse");
auto d_output_t = context.Input<Tensor>(framework::GradVarName("Out")); if (is_sparse) {
auto d_table_t = context.Output<Tensor>(framework::GradVarName("W")); auto* ids = context.Input<Tensor>("Ids");
auto* table = context.Input<Tensor>("W");
auto* d_output = context.Input<Tensor>(framework::GradVarName("Out"));
auto* d_table = context.Output<SelectedRows>(framework::GradVarName("W"));
int N = d_table_t->dims()[0]; auto* ids_data = ids->data<int64_t>();
int D = d_table_t->dims()[1]; auto ids_dim = ids->dims();
auto ids = ids_t->data<int32_t>();
const T* d_output = d_output_t->data<T>();
T* d_table = d_table_t->mutable_data<T>(context.GetPlace());
auto t = framework::EigenVector<T>::Flatten(*d_table_t); framework::Vector<int64_t> new_rows;
t.device(context.GetEigenDevice<platform::CPUPlace>()) = new_rows.reserve(ids_dim[0]);
t.constant(static_cast<T>(0)); for (int64_t i = 0; i < ids_dim[0]; i++) {
new_rows.push_back(ids_data[i]);
}
d_table->set_rows(new_rows);
for (int64_t i = 0; i < ids_t->numel(); ++i) { auto* d_table_value = d_table->mutable_value();
PADDLE_ENFORCE_LT(ids[i], N); d_table_value->Resize({ids_dim[0], table->dims()[1]});
PADDLE_ENFORCE_GE(ids[i], 0); d_table_value->mutable_data<T>(context.GetPlace());
for (int j = 0; j < D; ++j) {
d_table[ids[i] * D + j] += d_output[i * D + j]; d_table->set_height(table->dims()[0]);
auto* d_output_data = d_output->data<T>();
auto* d_table_data = d_table_value->data<T>();
PADDLE_ENFORCE_EQ(d_table_value->dims(), d_output->dims());
memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());
} else {
auto* ids = context.Input<Tensor>("Ids");
auto* d_output = context.Input<Tensor>(framework::GradVarName("Out"));
auto* d_table = context.Output<Tensor>(framework::GradVarName("W"));
auto* table = context.Input<Tensor>("W");
auto* ids_data = ids->data<int64_t>();
auto ids_dim = ids->dims();
int N = table->dims()[0];
int D = d_output->dims()[1];
auto* d_output_data = d_output->data<T>();
auto* d_table_data = d_table->mutable_data<T>(context.GetPlace());
for (int64_t i = 0; i < ids->numel(); ++i) {
PADDLE_ENFORCE_LT(ids_data[i], N);
PADDLE_ENFORCE_GE(ids_data[i], 0);
for (int j = 0; j < D; ++j) {
d_table_data[ids_data[i] * D + j] = d_output_data[i * D + j];
}
} }
} }
} }
......
...@@ -44,7 +44,7 @@ class CrossEntropyFunctor<platform::CPUPlace, T> { ...@@ -44,7 +44,7 @@ class CrossEntropyFunctor<platform::CPUPlace, T> {
const T* prob_data = prob->data<T>(); const T* prob_data = prob->data<T>();
T* loss_data = out->data<T>(); T* loss_data = out->data<T>();
const int* label_data = labels->data<int>(); const int64_t* label_data = labels->data<int64_t>();
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
int index = i * class_num + label_data[i]; int index = i * class_num + label_data[i];
loss_data[i] = -math::TolerableValue<T>()(std::log(prob_data[index])); loss_data[i] = -math::TolerableValue<T>()(std::log(prob_data[index]));
......
...@@ -20,7 +20,7 @@ namespace math { ...@@ -20,7 +20,7 @@ namespace math {
namespace { namespace {
template <typename T> template <typename T>
__global__ void CrossEntropyKernel(T* Y, const T* X, const int* label, __global__ void CrossEntropyKernel(T* Y, const T* X, const int64_t* label,
const int N, const int D) { const int N, const int D) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) { i += blockDim.x * gridDim.x) {
...@@ -115,7 +115,7 @@ class CrossEntropyFunctor<platform::GPUPlace, T> { ...@@ -115,7 +115,7 @@ class CrossEntropyFunctor<platform::GPUPlace, T> {
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>( reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
loss_data, prob_data, label_data, class_num); loss_data, prob_data, label_data, class_num);
} else { } else {
const int* label_data = labels->data<int>(); const int64_t* label_data = labels->data<int64_t>();
int block = 512; int block = 512;
int grid = (batch_size + block - 1) / block; int grid = (batch_size + block - 1) / block;
CrossEntropyKernel<T><<< CrossEntropyKernel<T><<<
......
if(WITH_GPU)
nv_library(nccl_common SRCS nccl_gpu_common.cc DEPS device_context operator )
endif()
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/nccl/nccl_gpu_common.h"
#include "paddle/platform/gpu_info.h"
namespace paddle {
namespace platform {} // namespace platform
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/platform/device_context.h"
#include "paddle/platform/dynload/nccl.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/macros.h"
namespace paddle {
namespace platform {
constexpr int kInvalidGPUId = -1;
struct Communicator {
std::vector<ncclComm_t> comms_;
std::unordered_map<int, int> comm_id_map_;
Communicator() {}
int GetCommId(int device_id) const { return comm_id_map_.at(device_id); }
void InitAll(const std::vector<int>& gpus) {
comms_.resize(gpus.size());
for (size_t i = 0; i < gpus.size(); ++i) {
comm_id_map_[gpus[i]] = i;
}
PADDLE_ENFORCE(
dynload::ncclCommInitAll(comms_.data(), gpus.size(), gpus.data()));
}
~Communicator() {
for (size_t i = 0; i < comms_.size(); ++i) {
// FIXME(dzh) : PADDLE_ENFORCE return void
dynload::ncclCommDestroy(comms_[i]);
}
}
DISABLE_COPY_AND_ASSIGN(Communicator);
};
} // namespace platform
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/op_registry.h"
#include "paddle/operators/nccl/nccl_gpu_common.h"
namespace paddle {
namespace operators {
// NCCLinitOp
class NCCLInitOp : public framework::OperatorBase {
public:
NCCLInitOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
const auto &name = Output("Communicator");
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(name),
"Can not find variable '%s' in the scope.", name);
std::vector<int> gpus = Attr<std::vector<int>>("gpus");
PADDLE_ENFORCE(!gpus.empty(), "Attr(gpus) should not be empty.");
if (scope.FindVar(name) == nullptr) {
PADDLE_THROW("Output(Communicator) is needed for ncclInit operator.");
}
platform::Communicator *comm =
scope.FindVar(name)->GetMutable<platform::Communicator>();
comm->InitAll(gpus);
}
};
class NCCLInitOpMaker : public framework::OpProtoAndCheckerMaker {
public:
NCCLInitOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddOutput("Communicator",
"Create Communicator for communicating between gpus");
AddAttr<std::vector<int>>("gpus", "gpu id lists");
AddAttr<int>("data_type", "output data type")
.SetDefault(framework::DataType::FP32);
AddComment(R"DOC(
create communicator.
)DOC");
}
};
// AllReduceOp
class NCCLAllReduceOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
" Input(X) of AllReduce op input should not be NULL");
PADDLE_ENFORCE(
ctx->HasInput("Communicator"),
" Input(Communicator) of AllReduce op input should not be NULL");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
" Input(X) of AllReduce op input should not be NULL");
auto x_dims = ctx->GetInputsDim("X");
std::string reduction = ctx->Attrs().Get<std::string>("reduction");
PADDLE_ENFORCE((reduction == "ncclSum" || reduction == "ncclProd" ||
reduction == "ncclMin" || reduction == "ncclMax"),
"invalid reduction.");
ctx->SetOutputsDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
};
// ReduceOp
class NCCLReduceOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
" Input(X) of Reduce op input should not be NULL");
PADDLE_ENFORCE(
ctx->HasInput("Communicator"),
" Input(Communicator) of Reduce op input should not be NULL");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
" Input(X) of Reduce op input should not be NULL");
std::string reduction = ctx->Attrs().Get<std::string>("reduction");
PADDLE_ENFORCE((reduction == "ncclSum" || reduction == "ncclProd" ||
reduction == "ncclMin" || reduction == "ncclMax"),
"invalid reduction.");
auto x_dims = ctx->GetInputsDim("X");
ctx->SetOutputsDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
};
// BcastOp
class NCCLBcastOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
" Input(X) of Bcast op input should not be NULL");
PADDLE_ENFORCE(ctx->HasInput("Communicator"),
" Input(Communicator) of Bcast op input should not be NULL");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
" Output(Out) of Bcast op output should not be NULL");
int root = ctx->Attrs().Get<int>("root");
PADDLE_ENFORCE(root != platform::kInvalidGPUId, "Bcast root must be set.");
auto x_dims = ctx->GetInputsDim("X");
ctx->SetOutputsDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
};
// AllreduceOp
class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
NCCLAllReduceOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of AllReduce op");
AddInput("Communicator", "Communicator for communicating between gpus");
AddOutput("Out", "The output of AllReduce op");
AddAttr<std::string>("reduction",
"{'ncclMin', 'ncclMax', 'ncclProd', 'ncclSum'}.")
.SetDefault("ncclSum");
AddComment(R"DOC(
AllReduce the input tensors.
)DOC");
}
};
// ReduceOp
class NCCLReduceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
NCCLReduceOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of Reduce op");
AddInput("Communicator", "Communicator for communicating between gpus");
AddOutput("Out", "The output of Reduce op");
AddAttr<std::string>("reduction",
"{'ncclMin', 'ncclMax', 'ncclProd', 'ncclSum'}.")
.SetDefault("ncclSum");
AddAttr<int>("root",
"root gpu of the parameter. if not "
"set(platform::kInvalidGPUId). hashed by name.")
.SetDefault(platform::kInvalidGPUId);
AddComment(R"DOC(
Reduce the tensors)DOC");
}
};
// BcastOp
class NCCLBcastOpMaker : public framework::OpProtoAndCheckerMaker {
public:
NCCLBcastOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of BcastSend op");
AddInput("Communicator", "Communicator for communicating between gpus");
AddOutput("Out", "The output of Bcast");
AddAttr<int>("root",
"root gpu of the parameter. if not "
"set(platform::kInvalidGPUId). hashed by name.")
.SetDefault(platform::kInvalidGPUId);
AddComment(R"DOC(
Bcast the tensors.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(ncclInit, ops::NCCLInitOp,
paddle::framework::EmptyGradOpMaker, ops::NCCLInitOpMaker);
REGISTER_OP_WITHOUT_GRADIENT(ncclAllReduce, ops::NCCLAllReduceOp,
ops::NCCLAllReduceOpMaker);
REGISTER_OP_WITHOUT_GRADIENT(ncclBcast, ops::NCCLBcastOp,
ops::NCCLBcastOpMaker);
REGISTER_OP_WITHOUT_GRADIENT(ncclReduce, ops::NCCLReduceOp,
ops::NCCLReduceOpMaker);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenseshashernless required by applicable law or agreed
to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <functional>
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/nccl/nccl_gpu_common.h"
namespace paddle {
namespace operators {
using framework::Tensor;
using platform::Communicator;
using framework::LoDTensor;
template <typename Type>
class NCCLTypeWrapper;
template <>
class NCCLTypeWrapper<float> {
public:
static const ncclDataType_t type = ncclFloat;
};
template <>
class NCCLTypeWrapper<double> {
public:
static const ncclDataType_t type = ncclDouble;
};
template <typename T>
class NCCLAllReduceKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device.");
auto ins = ctx.MultiInput<LoDTensor>("X");
auto outs = ctx.MultiOutput<LoDTensor>("Out");
std::string reduction = ctx.Attr<std::string>("reduction");
ncclRedOp_t reduction_op_ = ncclSum;
if (reduction == "ncclMin") {
reduction_op_ = ncclMin;
} else if (reduction == "ncclMax") {
reduction_op_ = ncclMax;
} else if (reduction == "ncclSum") {
reduction_op_ = ncclSum;
} else if (reduction == "ncclProd") {
reduction_op_ = ncclProd;
} else {
PADDLE_THROW("Invalid reduction. default ncclSum.");
}
auto* comm = ctx.Input<Communicator>("Communicator");
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream();
// device id
int gpu_id = boost::get<platform::GPUPlace>(ctx.GetPlace()).GetDeviceId();
int idx = comm->GetCommId(gpu_id);
for (size_t i = 0; i < ins.size(); ++i) {
VLOG(1) << "gpu : "
<< " invoke allreduce. send " << ins[i]->numel() << " recv "
<< outs[i]->numel();
PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
ins[i]->data<T>(), outs[i]->mutable_data<T>(ctx.GetPlace()),
outs[i]->numel(), NCCLTypeWrapper<T>::type, reduction_op_,
comm->comms_[idx], stream));
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
VLOG(1) << "gpu : "
<< " finished allreduce. send " << ins[i]->numel() << " recv "
<< outs[i]->numel();
}
}
};
template <typename T>
class NCCLReduceKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device.");
auto ins = ctx.MultiInput<LoDTensor>("X"); // x0, x1, x2
auto outs = ctx.MultiOutput<LoDTensor>("Out");
std::string reduction = ctx.Attr<std::string>("reduction");
ncclRedOp_t reduction_op_ = ncclSum;
if (reduction == "ncclMin") {
reduction_op_ = ncclMin;
} else if (reduction == "ncclMax") {
reduction_op_ = ncclMax;
} else if (reduction == "ncclSum") {
reduction_op_ = ncclSum;
} else if (reduction == "ncclProd") {
reduction_op_ = ncclProd;
} else {
PADDLE_THROW("Invalid reduction. default ncclSum.");
}
int root = ctx.Attr<int>("root");
auto* comm = ctx.Input<Communicator>("Communicator");
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream();
// device id
int gpu_id = boost::get<platform::GPUPlace>(ctx.GetPlace()).GetDeviceId();
int idx = comm->GetCommId(gpu_id);
auto ins_names = ctx.Inputs("X");
std::hash<std::string> hasher;
for (size_t i = 0; i < ins.size(); ++i) {
if (root == platform::kInvalidGPUId) {
root = hasher(ins_names[i]) % comm->comms_.size();
}
T* recvbuffer = nullptr;
if (root == gpu_id) {
recvbuffer = outs[i]->mutable_data<T>(ctx.GetPlace());
}
VLOG(1) << "gpu : " << gpu_id << " invoke reduce. send "
<< ins[i]->numel() << " recv " << outs[i]->numel();
PADDLE_ENFORCE(platform::dynload::ncclReduce(
ins[i]->data<T>(), recvbuffer, ins[i]->numel(),
NCCLTypeWrapper<T>::type, reduction_op_, root, comm->comms_[idx],
stream));
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
VLOG(1) << "gpu : " << gpu_id << " finished reduce. send "
<< ins[i]->numel() << " recv " << outs[i]->numel();
}
}
};
template <typename T>
class NCCLBcastKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device.");
int root = ctx.Attr<int>("root");
auto* comm = ctx.Input<Communicator>("Communicator");
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream();
// device id
int gpu_id = boost::get<platform::GPUPlace>(ctx.GetPlace()).GetDeviceId();
int idx = comm->GetCommId(gpu_id);
if (idx == root) {
auto ins = ctx.MultiInput<LoDTensor>("X");
for (size_t i = 0; i < ins.size(); ++i) {
VLOG(1) << "gpu : " << gpu_id << " invoke Bcast. send "
<< ins[i]->numel();
VLOG(1) << " before ncclBcast";
PADDLE_ENFORCE(platform::dynload::ncclBcast(
(void*)ins[i]->data<T>(), ins[i]->numel(), NCCLTypeWrapper<T>::type,
root, comm->comms_[idx], stream));
VLOG(1) << " after ncclBcast";
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
VLOG(1) << "gpu : " << gpu_id << " finished Bcast.";
}
} else {
auto outs = ctx.MultiOutput<LoDTensor>("Out");
for (size_t i = 0; i < outs.size(); ++i) {
VLOG(1) << "gpu : " << gpu_id << " invoke Bcast. recv buffer "
<< framework::product(outs[i]->dims());
PADDLE_ENFORCE(platform::dynload::ncclBcast(
outs[i]->mutable_data<T>(ctx.GetPlace()), outs[i]->numel(),
NCCLTypeWrapper<T>::type, root, comm->comms_[idx], stream));
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
VLOG(1) << "gpu : " << gpu_id << " finished Bcast. recv "
<< outs[i]->numel();
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(ncclAllReduce, ops::NCCLAllReduceKernel<float>);
REGISTER_OP_GPU_KERNEL(ncclBcast, ops::NCCLBcastKernel<float>);
REGISTER_OP_GPU_KERNEL(ncclReduce, ops::NCCLReduceKernel<float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <algorithm>
#include <memory>
#include <mutex>
#include <thread>
#include <utility>
#include <vector>
#include "paddle/framework/block_desc.h"
#include "paddle/framework/op_desc.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/program_desc.h"
#include "paddle/framework/var_desc.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/nccl/nccl_gpu_common.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/gpu_info.h"
#include "paddle/platform/place.h"
USE_NO_KERNEL_OP(ncclInit);
USE_GPU_ONLY_OP(ncclAllReduce);
USE_GPU_ONLY_OP(ncclReduce);
USE_GPU_ONLY_OP(ncclBcast);
namespace f = paddle::framework;
namespace p = paddle::platform;
static std::vector<int> gpu_list;
// test data amount
const f::DDim kDims = {100, 100};
// nccl op common tester, init communicator.
class NCCLTester : public ::testing::Test {
public:
virtual void SetUp() override {
cpu_ctx = new p::CPUDeviceContext(p::CPUPlace());
for (size_t i = 0; i < gpu_list.size(); ++i) {
p::GPUPlace place(i);
dev_ctxs.emplace_back(new p::CUDADeviceContext(place));
}
NCCLInitOp();
}
virtual void TearDown() override {
for (auto &device_context : dev_ctxs) {
delete device_context;
}
}
void NCCLInitOp() {
std::unique_ptr<f::OpDescBind> op1(new f::OpDescBind);
op1->SetType("ncclInit");
op1->SetOutput("Communicator", {"comm"});
op1->SetAttr("gpus", {gpu_list});
auto *var = g_scope.Var("comm");
var->GetMutable<p::Communicator>();
auto op = f::OpRegistry::CreateOp(*op1);
VLOG(1) << "invoke NCCLInitOp.";
op->Run(g_scope, *cpu_ctx);
VLOG(1) << "NCCLInitOp finished.";
}
template <class T>
void PerThreadProgram(int gpu_id, const f::OpDescBind &op_desc,
f::Scope *scope) {
std::unique_lock<std::mutex> lk(mu);
const f::OpDescBind *op1 = &op_desc;
p::GPUPlace place(gpu_id);
auto &ctx = dev_ctxs.at(gpu_id);
auto *send_tensor = scope->Var("st")->GetMutable<f::LoDTensor>();
auto *recv_tensor = scope->Var("rt")->GetMutable<f::LoDTensor>();
if (!send_tensor->numel()) {
send_tensor->Resize(kDims);
send_tensor->mutable_data<T>(kDims, place);
std::vector<T> send_vector(f::product(kDims), gpu_id);
send_tensor->CopyFromVector<T>(send_vector, *ctx);
ctx->Wait();
VLOG(1) << "Send Tensor filled with elements " << send_tensor->numel();
}
lk.unlock();
PADDLE_ENFORCE(send_tensor->numel() == f::product(kDims),
"Tensor numel not match!");
auto op = f::OpRegistry::CreateOp(*op1);
VLOG(1) << "Device : " << gpu_id << " invoke " << op_desc.Type();
VLOG(1) << " send_tensor : " << send_tensor->numel()
<< " recv_tensor : " << recv_tensor->numel();
op->Run(*scope, *ctx);
VLOG(1) << "Device : " << gpu_id << " finished " << op_desc.Type();
}
public:
std::vector<p::DeviceContext *> dev_ctxs;
p::DeviceContext *cpu_ctx;
f::Scope g_scope;
std::mutex mu;
};
// ncclInitOp with desc
TEST(NCCL, ncclInitOp) {
std::unique_ptr<f::OpDescBind> op_desc(new f::OpDescBind);
op_desc->SetType("ncclInit");
op_desc->SetOutput("Communicator", {"x1"});
op_desc->SetAttr("gpus", {gpu_list});
f::Scope g_scope;
std::unique_ptr<p::DeviceContext> ctx(new p::CPUDeviceContext(p::CPUPlace()));
auto *var = g_scope.Var("x1");
var->GetMutable<p::Communicator>();
auto op = f::OpRegistry::CreateOp(*op_desc);
VLOG(1) << "invoke NCCLInitOp.";
op->Run(g_scope, *ctx.get());
VLOG(1) << "NCCLInitOp finished.";
}
// ncclAllReduceOp with desc
TEST_F(NCCLTester, ncclAllReduceOp) {
std::unique_ptr<f::OpDescBind> op2(new f::OpDescBind);
op2->SetType("ncclAllReduce");
op2->SetInput("X", {"st"});
op2->SetInput("Communicator", {"comm"});
op2->SetOutput("Out", {"rt"});
std::vector<f::Scope *> dev_scopes;
std::vector<std::thread> ths;
for (size_t i = 0; i < gpu_list.size(); ++i) {
dev_scopes.emplace_back(&g_scope.NewScope());
std::thread th(&NCCLTester::PerThreadProgram<float>, this, gpu_list[i],
*op2.get(), dev_scopes[i]);
ths.emplace_back(std::move(th));
}
for (size_t i = 0; i < gpu_list.size(); ++i) {
ths[i].join();
}
// check results
float result = std::accumulate(gpu_list.begin(), gpu_list.end(), 0);
for (size_t i = 0; i < dev_scopes.size(); ++i) {
p::CPUPlace cpu_place;
p::GPUPlace gpu_place(gpu_list[i]);
auto &recv_tensor = dev_scopes[i]->FindVar("rt")->Get<f::LoDTensor>();
auto *rt = recv_tensor.data<float>();
auto *result_tensor = dev_scopes[i]->Var("ct")->GetMutable<f::LoDTensor>();
result_tensor->Resize(kDims);
auto *ct = result_tensor->mutable_data<float>(cpu_place);
paddle::memory::Copy(
cpu_place, ct, p::GPUPlace(gpu_list[i]), rt,
recv_tensor.numel() * sizeof(float),
static_cast<p::CUDADeviceContext *>(dev_ctxs[i])->stream());
for (size_t j = 0; j < f::product(kDims); ++j) {
ASSERT_NEAR(ct[j], result, 1e-5);
}
}
}
// ncclReduceOp with desc
TEST_F(NCCLTester, ncclReduceOp) {
std::unique_ptr<f::OpDescBind> op2(new f::OpDescBind);
const int kRoot = 0;
op2->SetType("ncclReduce");
op2->SetInput("X", {"st"});
op2->SetInput("Communicator", {"comm"});
op2->SetOutput("Out", {"rt"});
op2->SetAttr("root", kRoot);
std::vector<f::Scope *> dev_scopes;
std::vector<std::thread> ths;
for (size_t i = 0; i < gpu_list.size(); ++i) {
dev_scopes.emplace_back(&g_scope.NewScope());
std::thread th(&NCCLTester::PerThreadProgram<float>, this, gpu_list[i],
*op2.get(), dev_scopes[i]);
ths.emplace_back(std::move(th));
}
for (size_t i = 0; i < gpu_list.size(); ++i) {
ths[i].join();
}
// check results on
float result = std::accumulate(gpu_list.begin(), gpu_list.end(), 0);
p::CPUPlace cpu_place;
p::GPUPlace gpu_place(gpu_list[kRoot]);
auto &recv_tensor = dev_scopes[kRoot]->FindVar("rt")->Get<f::LoDTensor>();
auto *rt = recv_tensor.data<float>();
auto *result_tensor =
dev_scopes[kRoot]->Var("ct")->GetMutable<f::LoDTensor>();
result_tensor->Resize(kDims);
auto *ct = result_tensor->mutable_data<float>(cpu_place);
paddle::memory::Copy(
cpu_place, ct, p::GPUPlace(gpu_list[kRoot]), rt,
recv_tensor.numel() * sizeof(float),
static_cast<p::CUDADeviceContext *>(dev_ctxs[kRoot])->stream());
for (int j = 0; j < f::product(kDims); ++j) {
ASSERT_NEAR(ct[j], result, 1e-5);
}
}
// ncclBcastOp with desc
TEST_F(NCCLTester, ncclBcastOp) {
std::unique_ptr<f::OpDescBind> op2(new f::OpDescBind);
const int kRoot = 5;
op2->SetType("ncclBcast");
op2->SetInput("X", {"st"});
op2->SetInput("Communicator", {"comm"});
op2->SetOutput("Out", {"rt"});
op2->SetAttr("root", kRoot);
std::vector<f::Scope *> dev_scopes;
std::vector<std::thread> ths;
for (size_t i = 0; i < gpu_list.size(); ++i) {
dev_scopes.emplace_back(&g_scope.NewScope());
std::thread th(&NCCLTester::PerThreadProgram<float>, this, gpu_list[i],
*op2.get(), dev_scopes[i]);
ths.emplace_back(std::move(th));
}
for (size_t i = 0; i < gpu_list.size(); ++i) {
ths[i].join();
}
const int idx = 1;
// check results on
float result = kRoot;
p::CPUPlace cpu_place;
p::GPUPlace gpu_place(gpu_list[idx]);
auto &recv_tensor = dev_scopes[idx]->FindVar("rt")->Get<f::LoDTensor>();
auto *rt = recv_tensor.data<float>();
auto *result_tensor = dev_scopes[idx]->Var("ct")->GetMutable<f::LoDTensor>();
result_tensor->Resize(kDims);
auto *ct = result_tensor->mutable_data<float>(cpu_place);
paddle::memory::Copy(
cpu_place, ct, p::GPUPlace(gpu_list[idx]), rt,
recv_tensor.numel() * sizeof(float),
static_cast<p::CUDADeviceContext *>(dev_ctxs[idx])->stream());
for (size_t j = 0; j < f::product(kDims); ++j) {
ASSERT_NEAR(ct[j], result, 1e-5);
}
}
int main(int argc, char **argv) {
const int dev_count = p::GetCUDADeviceCount();
if (dev_count <= 1) {
LOG(WARNING)
<< "Cannot test multi-gpu nccl, because the CUDA device count is "
<< dev_count;
return 0;
}
for (int i = 0; i < dev_count; ++i) {
gpu_list.emplace_back(i);
}
testing::InitGoogleTest(&argc, argv);
// device context should be release before scope.
// otherwise driver will down.
return RUN_ALL_TESTS();
}
...@@ -43,6 +43,7 @@ class PoolCudnnOpKernel : public framework::OpKernel<T> { ...@@ -43,6 +43,7 @@ class PoolCudnnOpKernel : public framework::OpKernel<T> {
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings"); std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
if (ctx.Attr<bool>("globalPooling")) { if (ctx.Attr<bool>("globalPooling")) {
for (size_t i = 0; i < ksize.size(); ++i) { for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0;
ksize[i] = static_cast<int>(input->dims()[i + 2]); ksize[i] = static_cast<int>(input->dims()[i + 2]);
} }
} }
...@@ -97,8 +98,10 @@ class PoolCudnnGradOpKernel : public framework::OpKernel<T> { ...@@ -97,8 +98,10 @@ class PoolCudnnGradOpKernel : public framework::OpKernel<T> {
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings"); std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
if (ctx.Attr<bool>("globalPooling")) { if (ctx.Attr<bool>("globalPooling")) {
for (size_t i = 0; i < ksize.size(); ++i) for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0;
ksize[i] = static_cast<int>(input->dims()[i + 2]); ksize[i] = static_cast<int>(input->dims()[i + 2]);
}
} }
const T *input_data = input->data<T>(); const T *input_data = input->data<T>();
......
...@@ -39,8 +39,10 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const { ...@@ -39,8 +39,10 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const {
if (ctx->Attrs().Get<bool>("globalPooling")) { if (ctx->Attrs().Get<bool>("globalPooling")) {
ksize.resize(static_cast<size_t>(in_x_dims.size()) - 2); ksize.resize(static_cast<size_t>(in_x_dims.size()) - 2);
for (size_t i = 0; i < ksize.size(); ++i) for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0;
ksize[i] = static_cast<int>(in_x_dims[i + 2]); ksize[i] = static_cast<int>(in_x_dims[i + 2]);
}
} }
PADDLE_ENFORCE(in_x_dims.size() - ksize.size() == 2U, PADDLE_ENFORCE(in_x_dims.size() - ksize.size() == 2U,
...@@ -84,15 +86,16 @@ Pool2dOpMaker::Pool2dOpMaker(framework::OpProto *proto, ...@@ -84,15 +86,16 @@ Pool2dOpMaker::Pool2dOpMaker(framework::OpProto *proto,
"(string), pooling type, can be \"max\" for max-pooling " "(string), pooling type, can be \"max\" for max-pooling "
"and \"avg\" for average-pooling.") "and \"avg\" for average-pooling.")
.InEnum({"max", "avg"}); .InEnum({"max", "avg"});
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>("ksize",
"ksize", "(vector ), the pooling window size(height, width) "
"(vector ), the pooling window size(height, width) of pooling operator." "of pooling operator."
"If globalPooling = true, ksize is ignored and need not be " "If globalPooling = true, ksize and paddings will "
"specified."); // TODO(Chengduo): Add checker. (Currently, "be ignored."); // TODO(Chengduo): Add checker.
// (Currently,
// TypedAttrChecker don't support vector type.) // TypedAttrChecker don't support vector type.)
AddAttr<bool>("globalPooling", AddAttr<bool>("globalPooling",
"(bool default: false), whether to use the global pooling." "(bool default: false), whether to use the global pooling."
"If globalPooling = true, ksize is ignored.") "If globalPooling = true, ksize and paddings will be ignored.")
.SetDefault(false); .SetDefault(false);
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>(
"strides", "strides",
...@@ -101,7 +104,8 @@ Pool2dOpMaker::Pool2dOpMaker(framework::OpProto *proto, ...@@ -101,7 +104,8 @@ Pool2dOpMaker::Pool2dOpMaker(framework::OpProto *proto,
// TypedAttrChecker don't support vector type.) // TypedAttrChecker don't support vector type.)
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>(
"paddings", "paddings",
"(vector defalut:{0,0}), paddings(height, width) of pooling operator.") "(vector defalut:{0,0}), paddings(height, width) of pooling operator."
"If globalPooling = true, paddings and ksize will be ignored.")
.SetDefault({0, 0}); // TODO(Chengduo): Add checker. (Currently, .SetDefault({0, 0}); // TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.) // TypedAttrChecker don't support vector type.)
...@@ -145,25 +149,28 @@ Pool3dOpMaker::Pool3dOpMaker(framework::OpProto *proto, ...@@ -145,25 +149,28 @@ Pool3dOpMaker::Pool3dOpMaker(framework::OpProto *proto,
"(string), pooling type, can be \"max\" for max-pooling " "(string), pooling type, can be \"max\" for max-pooling "
"and \"avg\" for average-pooling.") "and \"avg\" for average-pooling.")
.InEnum({"max", "avg"}); .InEnum({"max", "avg"});
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>("ksize",
"ksize", "(vector ), the pooling window size(depth, height, "
"(vector ), the pooling window size(depth, height, width) of pooling " "width) of pooling "
"operator." "operator."
"If globalPooling = true, ksize is ignored and need not be " "If globalPooling = true, ksize and paddings wille "
"specified."); // TODO(Chengduo): Add checker. (Currently, "be ignored."); // TODO(Chengduo): Add checker.
// TypedAttrChecker don't support vector type.) // (Currently,
// TypedAttrChecker don't support vector type.)
AddAttr<bool>("globalPooling", AddAttr<bool>("globalPooling",
"(bool default: false), whether to use the global pooling." "(bool default: false), whether to use the global pooling."
"If globalPooling = true, ksize is ignored.") "If globalPooling = true, ksize and paddings wille be ignored.")
.SetDefault(false); .SetDefault(false);
AddAttr<std::vector<int>>("strides", AddAttr<std::vector<int>>("strides",
"(vector, default:{1,1,1}), strides(depth, height, " "(vector, default:{1,1,1}), strides(depth, height, "
"width) of pooling operator.") "width) of pooling operator.")
.SetDefault({1, 1, 1}); // TODO(Chengduo): Add checker. (Currently, .SetDefault({1, 1, 1}); // TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.) // TypedAttrChecker don't support vector type.)
AddAttr<std::vector<int>>("paddings", AddAttr<std::vector<int>>(
"(vector defalut:{0,0,0}), paddings(depth, height, " "paddings",
"width) of pooling operator.") "(vector defalut:{0,0,0}), paddings(depth, height, "
"width) of pooling operator."
"If globalPooling = true, ksize and paddings wille be ignored.")
.SetDefault({0, 0, 0}); // TODO(Chengduo): Add checker. (Currently, .SetDefault({0, 0, 0}); // TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.) // TypedAttrChecker don't support vector type.)
......
...@@ -63,6 +63,7 @@ class PoolKernel : public framework::OpKernel<T> { ...@@ -63,6 +63,7 @@ class PoolKernel : public framework::OpKernel<T> {
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings"); std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
if (context.Attr<bool>("globalPooling")) { if (context.Attr<bool>("globalPooling")) {
for (size_t i = 0; i < ksize.size(); ++i) { for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0;
ksize[i] = static_cast<int>(in_x->dims()[i + 2]); ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
} }
} }
...@@ -103,6 +104,7 @@ class PoolKernel : public framework::OpKernel<T> { ...@@ -103,6 +104,7 @@ class PoolKernel : public framework::OpKernel<T> {
paddings, pool_process); paddings, pool_process);
} }
} break; } break;
default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); }
} }
} }
}; };
...@@ -123,8 +125,10 @@ class PoolGradKernel : public framework::OpKernel<T> { ...@@ -123,8 +125,10 @@ class PoolGradKernel : public framework::OpKernel<T> {
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings"); std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
if (context.Attr<bool>("globalPooling")) { if (context.Attr<bool>("globalPooling")) {
for (size_t i = 0; i < ksize.size(); ++i) for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0;
ksize[i] = static_cast<int>(in_x->dims()[i + 2]); ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
}
} }
if (in_x_grad) { if (in_x_grad) {
...@@ -164,6 +168,7 @@ class PoolGradKernel : public framework::OpKernel<T> { ...@@ -164,6 +168,7 @@ class PoolGradKernel : public framework::OpKernel<T> {
*out_grad, ksize, strides, paddings, pool_process); *out_grad, ksize, strides, paddings, pool_process);
} }
} break; } break;
default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); }
} }
} }
} }
......
...@@ -46,8 +46,10 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel { ...@@ -46,8 +46,10 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel {
if (ctx->Attrs().Get<bool>("globalPooling")) { if (ctx->Attrs().Get<bool>("globalPooling")) {
ksize.resize(static_cast<size_t>(in_x_dims.size()) - 2); ksize.resize(static_cast<size_t>(in_x_dims.size()) - 2);
for (size_t i = 0; i < ksize.size(); ++i) for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0;
ksize[i] = static_cast<int>(in_x_dims[i + 2]); ksize[i] = static_cast<int>(in_x_dims[i + 2]);
}
} }
PADDLE_ENFORCE(in_x_dims.size() - ksize.size() == 2U, PADDLE_ENFORCE(in_x_dims.size() - ksize.size() == 2U,
...@@ -87,31 +89,33 @@ class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -87,31 +89,33 @@ class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput( AddInput(
"X", "X",
"(Tensor) The input tensor of pooling operator. " "(Tensor), the input tensor of pooling operator. "
"The format of input tensor is NCHW. Where N is batch size, C is the " "The format of input tensor is NCHW. Where N is batch size, C is the "
"number of channels, H and W is the height and width of image."); "number of channels, H and W is the height and width of image.");
AddOutput("Out", AddOutput("Out",
"(Tensor) The output tensor of pooling operator." "(Tensor), the output tensor of pooling operator."
"The format of output tensor is also NCHW." "The format of output tensor is also NCHW."
"Where N is batch size, C is " "Where N is batch size, C is "
"the number of channels, H and W is the height and " "the number of channels, H and W is the height and "
"width of image."); "width of image.");
AddOutput("Mask", AddOutput("Mask",
"(Tensor) The Mask tensor of pooling operator." "(Tensor), the Mask tensor of pooling operator."
"The format of output tensor is also NCHW." "The format of output tensor is also NCHW."
"Where N is batch size, C is the number of channels, H and W " "Where N is batch size, C is the number of channels, H and W "
"is the height and width of image." "is the height and width of image."
"The value in it is the index in current feature map"); "The value in it is the index in current feature map");
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>("ksize",
"ksize", "(vector ), the pooling window size(height, "
"(vector ), the pooling window size(height, width) of pooling operator." "width) of pooling operator."
"If globalPooling = true, ksize is ignored and need not be " "If globalPooling = true, ksize and paddings "
"specified."); // TODO(Chengduo): Add checker. (Currently, "will be ignored."); // TODO(Chengduo): Add
// checker. (Currently,
// TypedAttrChecker don't support vector type.) // TypedAttrChecker don't support vector type.)
AddAttr<bool>("globalPooling", AddAttr<bool>(
"(bool default: false), whether to use the global pooling." "globalPooling",
"If globalPooling = true, ksize is ignored.") "(bool default: false), whether to use the global pooling."
"If globalPooling = true, ksize and paddings will be ignored.")
.SetDefault(false); .SetDefault(false);
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>(
"strides", "strides",
...@@ -120,7 +124,8 @@ class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -120,7 +124,8 @@ class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
// TypedAttrChecker don't support vector type.) // TypedAttrChecker don't support vector type.)
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>(
"paddings", "paddings",
"(vector defalut:{0,0}), paddings(height, width) of pooling operator.") "(vector defalut:{0, 0}), paddings(height, width) of pooling operator."
"If globalPooling = true, paddings and will be ignored.")
.SetDefault({0, 0}); // TODO(Chengduo): Add checker. (Currently, .SetDefault({0, 0}); // TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.) // TypedAttrChecker don't support vector type.)
...@@ -153,42 +158,46 @@ class MaxPool3dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -153,42 +158,46 @@ class MaxPool3dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput( AddInput(
"X", "X",
"(Tensor) The input tensor of pooling operator. " "(Tensor), the input tensor of pooling operator. "
"The format of input tensor is NCDHW. Where N is batch size, C is " "The format of input tensor is NCDHW. Where N is batch size, C is "
"the number of channels, D, H and W is the depth, height and width of " "the number of channels, D, H and W is the depth, height and width of "
"image."); "image.");
AddOutput("Out", AddOutput("Out",
"(Tensor) The output tensor of pooling operator." "(Tensor), the output tensor of pooling operator."
"The format of output tensor is also NCDHW." "The format of output tensor is also NCDHW."
"Where N is batch size, C is " "Where N is batch size, C is "
"the number of channels, D, H and W is the depth, height and " "the number of channels, D, H and W is the depth, height and "
"width of image."); "width of image.");
AddOutput("Mask", AddOutput("Mask",
"(Tensor) The Mask tensor of pooling operator." "(Tensor), the Mask tensor of pooling operator."
"The format of output tensor is also NCDHW." "The format of output tensor is also NCDHW."
"Where N is batch size, C is the number of channels, D, H and W " "Where N is batch size, C is the number of channels, D, H and W "
"is the depth, height and width of image." "is the depth, height and width of image."
"The value in it is the index in current feature map"); "The value in it is the index in current feature map");
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>("ksize",
"ksize", "(vector), the pooling window size(depth, "
"(vector ), the pooling window size(depth, height, width) of pooling " "height, width) of pooling "
"operator." "operator."
"If globalPooling = true, ksize is ignored and need not be " "If globalPooling = true, ksize and paddings "
"specified."); // TODO(Chengduo): Add checker. (Currently, "will be ignored."); // TODO(Chengduo): Add
// checker. (Currently,
// TypedAttrChecker don't support vector type.) // TypedAttrChecker don't support vector type.)
AddAttr<bool>("globalPooling", AddAttr<bool>(
"(bool default: false), whether to use the global pooling." "globalPooling",
"If globalPooling = true, ksize is ignored.") "(bool default: false), whether to use the global pooling."
"If globalPooling = true, ksize and paddings will be ignored.")
.SetDefault(false); .SetDefault(false);
AddAttr<std::vector<int>>("strides", AddAttr<std::vector<int>>("strides",
"(vector, default:{1,1,1}), strides(depth, " "(vector, default:{1,1,1}), strides(depth, "
"height, width) of pooling operator.") "height, width) of pooling operator.")
.SetDefault({1, 1, 1}); // TODO(Chengduo): Add checker. (Currently, .SetDefault({1, 1, 1}); // TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.) // TypedAttrChecker don't support vector type.)
AddAttr<std::vector<int>>("paddings", AddAttr<std::vector<int>>(
"(vector defalut:{0,0,0}), paddings(depth, " "paddings",
"height, width) of pooling operator.") "(vector defalut:{0,0,0}), paddings(depth, "
"height, width) of pooling operator."
"If globalPooling = true, paddings and ksize will be ignored.")
.SetDefault({0, 0, 0}); // TODO(Chengduo): Add checker. (Currently, .SetDefault({0, 0, 0}); // TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.) // TypedAttrChecker don't support vector type.)
......
...@@ -37,6 +37,7 @@ class MaxPoolWithIndexKernel : public framework::OpKernel<T> { ...@@ -37,6 +37,7 @@ class MaxPoolWithIndexKernel : public framework::OpKernel<T> {
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings"); std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
if (context.Attr<bool>("globalPooling")) { if (context.Attr<bool>("globalPooling")) {
for (size_t i = 0; i < ksize.size(); ++i) { for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0;
ksize[i] = static_cast<int>(in_x->dims()[i + 2]); ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
} }
} }
...@@ -54,6 +55,7 @@ class MaxPoolWithIndexKernel : public framework::OpKernel<T> { ...@@ -54,6 +55,7 @@ class MaxPoolWithIndexKernel : public framework::OpKernel<T> {
pool3d_forward(context.device_context(), *in_x, *out, *mask, ksize, pool3d_forward(context.device_context(), *in_x, *out, *mask, ksize,
strides, paddings); strides, paddings);
} break; } break;
default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); }
} }
} }
}; };
...@@ -72,6 +74,7 @@ class MaxPoolWithIndexGradKernel : public framework::OpKernel<T> { ...@@ -72,6 +74,7 @@ class MaxPoolWithIndexGradKernel : public framework::OpKernel<T> {
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings"); std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
if (context.Attr<bool>("globalPooling")) { if (context.Attr<bool>("globalPooling")) {
for (size_t i = 0; i < ksize.size(); ++i) { for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0;
ksize[i] = static_cast<int>(in_x_grad->dims()[i + 2]); ksize[i] = static_cast<int>(in_x_grad->dims()[i + 2]);
} }
} }
...@@ -95,6 +98,7 @@ class MaxPoolWithIndexGradKernel : public framework::OpKernel<T> { ...@@ -95,6 +98,7 @@ class MaxPoolWithIndexGradKernel : public framework::OpKernel<T> {
pool3d_backward(context.device_context(), *in_x_grad, *out_grad, pool3d_backward(context.device_context(), *in_x_grad, *out_grad,
*mask, ksize, strides, paddings); *mask, ksize, strides, paddings);
} break; } break;
default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); }
} }
} }
} }
......
...@@ -34,13 +34,19 @@ class ReshapeOp : public framework::OperatorWithKernel { ...@@ -34,13 +34,19 @@ class ReshapeOp : public framework::OperatorWithKernel {
auto shape = ctx->Attrs().Get<std::vector<int>>("shape"); auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
PADDLE_ENFORCE(shape.size() > 0, "Attr(shape) shouldn't be empty."); PADDLE_ENFORCE(shape.size() > 0, "Attr(shape) shouldn't be empty.");
for (auto dim : shape) { auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE(dim > 0, "Each dimension of shape must be positive."); // TODO(qiao) change batch_size
for (int i = 1; i < shape.size(); ++i) {
PADDLE_ENFORCE(shape[i] > 0,
"Each dimension of shape "
"must be positiv except the first.");
}
if (shape[0] < 0) {
shape[0] = x_dims[0];
} }
// capacity check // capacity check
int64_t capacity = int64_t capacity =
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()); std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
auto x_dims = ctx->GetInputDim("X");
int64_t in_size = framework::product(x_dims); int64_t in_size = framework::product(x_dims);
PADDLE_ENFORCE_EQ(capacity, in_size, PADDLE_ENFORCE_EQ(capacity, in_size,
"The size of Input(X) mismatches with Attr(shape)."); "The size of Input(X) mismatches with Attr(shape).");
......
...@@ -26,13 +26,8 @@ class ReshapeKernel : public framework::OpKernel<T> { ...@@ -26,13 +26,8 @@ class ReshapeKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const { void Compute(const framework::ExecutionContext& ctx) const {
auto* out = ctx.Output<framework::Tensor>("Out"); auto* out = ctx.Output<framework::Tensor>("Out");
auto* in = ctx.Input<framework::Tensor>("X"); auto* in = ctx.Input<framework::Tensor>("X");
auto out_dims = out->dims();
out->mutable_data<T>(ctx.GetPlace()); out->mutable_data<T>(ctx.GetPlace());
auto shape = ctx.Attr<std::vector<int>>("shape");
std::vector<int64_t> shape_int64(shape.size(), 0);
std::transform(shape.begin(), shape.end(), shape_int64.begin(),
[](int a) { return static_cast<int64_t>(a); });
auto out_dims = framework::make_ddim(shape_int64);
out->CopyFrom(*in, ctx.GetPlace(), ctx.device_context()); out->CopyFrom(*in, ctx.GetPlace(), ctx.device_context());
out->Resize(out_dims); out->Resize(out_dims);
} }
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/seq_expand_op.h"
namespace paddle {
namespace operators {
using framework::Tensor;
class SeqExpandOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"));
PADDLE_ENFORCE(ctx->HasOutput("Out"));
PADDLE_ENFORCE(ctx->HasInput("Y"));
framework::DDim out_dim;
out_dim = ctx->GetInputDim("Y");
ctx->ShareLoD("Y", "Out");
ctx->SetOutputDim("Out", out_dim);
}
};
class SeqExpandOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SeqExpandOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(Tensor or LoDTensor) The input(X) of this operator can be a "
"LoDTensor or a base Tensor.");
AddInput("Y",
"(LoDTensor)The reference input(Y) of seq_expand op."
"It must be a LoDTensor with k-level(k>0)."
"The input(X) will be expanded according to LOD of input(Y)."
"The element numbers of last level in input(Y) "
"must be equal to dims[0] of input(X).");
AddOutput("Out",
"(LodTensor)The output of seq_expand op."
"The lod of output will be as same as input(Y)'s lod.");
AddComment(R"DOC(
Expand input(X) according to LOD of input(Y).
Case 1:
Given 2-level a LoDTensor input(X)
X.lod = [[0, 2, 3],
[0, 1, 3, 4]]
X.data = [a, b, c, d]
X.dims = [4, 1]
and input(Y)
Y.lod = [[0, 2, 4],
[0, 3, 6, 7, 8]]
with condition len(Y.lod[-1]) -1 == X.dims[0]
then we get 2-level LoDTensor
Out.lod = [[0, 2, 4],
[0, 3, 6, 7, 8]]
Out.data = [a, a, a, b, b, b, c, d]
Out.dims = [8, 1]
Case 2:
Given a 0-level LoDTensor input(X)
X.data = [a, b, c]
X.lod = NULL
X.dims = [3, 1]
and input(Y)
Y.lod = [[0, 2, 3, 6]]
with condition len(Y.lod[-1]) -1 == X.dims[0]
then we get 1-level LoDTensor
Out.lod = [[0, 2, 3, 6]]
Out.data = [a, a, b, c, c, c]
Out.dims = [6, 1]
Case 3:
Given a 0-level LoDTensor input(X)
X.data = [[a, b], [c, d], [e, f]]
X.lod = NULL
X.dims = [3, 2]
and input(Y)
Y.lod = [[0, 2, 3, 6]]
with condition len(Y.lod[-1]) -1 == X.dims[0]
then we get 1-level LoDTensor
Out.lod = [[0, 2, 3, 6]]
Out.data = [[a,b], [a,b] [c,d], [e, f], [e, f], [e, f]]
Out.dims = [6, 2]
Case 4:
Given 2-level a LoDTensor input(X)
X.lod = [[0, 2, 3],
[0, 1, 3, 4]]
X.data = [a, b, c, d]
X.dims = [4, 1]
and input(Y)
Y.lod = [[0, 2, 4],
[0, 3, 6, 6, 8]]
with condition len(Y.lod[-1]) -1 == X.dims[0]
then we get 2-level LoDTensor
Out.lod = [[0, 2, 4],
[0, 3, 6, 6, 8]]
Out.data = [a, a, a, b, b, b, d, d]
Out.dims = [8, 1]
)DOC");
}
};
class SeqExpandOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"));
PADDLE_ENFORCE(ctx->HasInput("Out"));
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"The input(Out@GRAD) should not be null");
auto x_dims = ctx->GetInputDim("X");
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(seq_expand, ops::SeqExpandOp, ops::SeqExpandOpMaker,
seq_expand_grad, ops::SeqExpandOpGrad);
REGISTER_OP_CPU_KERNEL(seq_expand,
ops::SeqExpandKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
seq_expand_grad,
ops::SeqExpandGradKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/seq_expand_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(seq_expand,
ops::SeqExpandKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
seq_expand_grad,
ops::SeqExpandGradKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/op_registry.h"
#include "paddle/memory/memcpy.h"
#include "unsupported/Eigen/CXX11/Tensor"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
template <typename Place, typename T>
class SeqExpandKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<LoDTensor>("X");
auto* out = context.Output<LoDTensor>("Out");
const T* x_data = x->data<T>();
auto x_dims = x->dims();
auto* y = context.Input<LoDTensor>("Y");
PADDLE_ENFORCE_EQ(x_dims[0], y->lod().back().size() - 1,
"The size of last lod level in Input(Y)"
"must be equal to dims[0] of Input(X).");
out->set_lod(y->lod());
auto place = context.GetEigenDevice<Place>();
size_t element_len = framework::product(x_dims) / x_dims[0];
T* out_data = out->mutable_data<T>(context.GetPlace());
auto out_starts = out->lod().back();
for (size_t i = 0; i < out_starts.size() - 1; i++) {
int scale = out_starts[i + 1] - out_starts[i];
Eigen::TensorMap<
Eigen::Tensor<const T, 2, Eigen::RowMajor, Eigen::DenseIndex>>
x_t(x_data, 1, element_len);
Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor, Eigen::DenseIndex>>
out_t(out_data, scale, element_len);
Eigen::array<int, 2> cast({scale, 1});
out_t.device(place) = x_t.broadcast(cast);
x_data += element_len;
out_data += element_len * scale;
}
}
};
/*
*Given Grad(Out)
*
* Grad(Out).lod = [[0, 2],
* [0, 3, 6]]
* Grad(Out).data = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
* Then
* Grad(X).data = [(0.1 + 0.2 + 0.3), (0.4 + 0.5 + 0.6)]
* = [0.6, 1.5]
* Grad(X).lod = Input(X).lod
*
* */
template <typename Place, typename T>
class SeqExpandGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* d_out = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto* x = context.Input<LoDTensor>("X");
auto* out = context.Input<LoDTensor>("Out");
auto* d_x = context.Output<LoDTensor>(framework::GradVarName("X"));
auto out_last_level = out->lod().back();
d_x->set_lod(x->lod());
const T* d_out_data = d_out->data<T>();
T* d_x_data = d_x->mutable_data<T>(context.GetPlace());
size_t element_len = d_out->numel() / d_out->dims()[0];
for (size_t i = 0; i < out_last_level.size() - 1; ++i) {
size_t repeat = out_last_level[i + 1] - out_last_level[i];
Eigen::TensorMap<
Eigen::Tensor<const T, 2, Eigen::RowMajor, Eigen::DenseIndex>>
d_out_t(d_out_data, static_cast<int>(repeat), element_len);
Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, Eigen::DenseIndex>>
d_x_t(d_x_data, static_cast<int>(element_len));
auto place = context.GetEigenDevice<Place>();
d_x_t.device(place) = d_out_t.sum(Eigen::array<int, 1>({{0}}));
d_out_data += (repeat * element_len);
d_x_data += element_len;
}
}
};
} // namespace operators
} // namespace paddle
...@@ -68,12 +68,12 @@ class SequenceConcatOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -68,12 +68,12 @@ class SequenceConcatOpMaker : public framework::OpProtoAndCheckerMaker {
"The level should be less than the level number of inputs.") "The level should be less than the level number of inputs.")
.SetDefault(0); .SetDefault(0);
AddComment(R"DOC( AddComment(R"DOC(
The sequence_concat operator concatenates multiple LoDTensors. The sequence_concat operator concatenates multiple LoDTensors.
It only supports sequence (LoD Tensor with level number is 1) It only supports sequence (LoD Tensor with level number is 1)
or a nested sequence (LoD tensor with level number is 2) as its input. or a nested sequence (LoD tensor with level number is 2) as its input.
- Case1: - Case1:
If the axis is other than 0(here, axis is 1 and level is 1), If the axis is other than 0(here, axis is 1 and level is 1),
each input should have the same LoD information and the LoD each input should have the same LoD information and the LoD
information of the output keeps the same as the input. information of the output keeps the same as the input.
LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (4,3,4) LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (4,3,4)
...@@ -81,7 +81,7 @@ class SequenceConcatOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -81,7 +81,7 @@ class SequenceConcatOpMaker : public framework::OpProtoAndCheckerMaker {
LoD(Out) = {{0,2,4}, {0,1,2,3,4}}; Dims(Out) = (4,7,4) LoD(Out) = {{0,2,4}, {0,1,2,3,4}}; Dims(Out) = (4,7,4)
- Case2: - Case2:
If the axis is 0(here, leve is 0), the inputs are concatenated along If the axis is 0(here, leve is 0), the inputs are concatenated along
time steps, the LoD information of the output need to re-compute. time steps, the LoD information of the output need to re-compute.
LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (4,3,4) LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (4,3,4)
...@@ -94,7 +94,7 @@ class SequenceConcatOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -94,7 +94,7 @@ class SequenceConcatOpMaker : public framework::OpProtoAndCheckerMaker {
LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (4,3,4) LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (4,3,4)
LoD(x1) = {{0,3,5}, {0,1,3,4,5}}; Dims(x1) = (5,3,4) LoD(x1) = {{0,3,5}, {0,1,3,4,5}}; Dims(x1) = (5,3,4)
LoD(Out) = {{0,5,9}, {0,2,5,7,9}}; Dims(Out) = (9,3,4) LoD(Out) = {{0,5,9}, {0,2,5,7,9}}; Dims(Out) = (9,3,4)
NOTE: The levels of all the inputs should be the same. NOTE: The levels of all the inputs should be the same.
)DOC"); )DOC");
} }
......
...@@ -144,11 +144,11 @@ class SequencePoolGradKernel : public framework::OpKernel<T> { ...@@ -144,11 +144,11 @@ class SequencePoolGradKernel : public framework::OpKernel<T> {
Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
in_t_map(in_t.data<T>(), h, w); in_t_map(in_t.data<T>(), h, w);
int row_id; int row_id;
Eigen::array<int, 2> extents = {1, 1}; Eigen::array<int, 2> extents{{1, 1}};
for (int col_id = 0; col_id < w; col_id++) { for (int col_id = 0; col_id < w; col_id++) {
in_t_map.col(col_id).maxCoeff(&row_id); in_t_map.col(col_id).maxCoeff(&row_id);
Eigen::array<int, 2> in_offsets = {row_id, col_id}; Eigen::array<int, 2> in_offsets{{row_id, col_id}};
Eigen::array<int, 2> out_offsets = {0, col_id}; Eigen::array<int, 2> out_offsets{{0, col_id}};
in_g_e.slice(in_offsets, extents).device(place) = in_g_e.slice(in_offsets, extents).device(place) =
out_g_e.slice(out_offsets, extents); out_g_e.slice(out_offsets, extents);
} }
......
...@@ -89,11 +89,12 @@ struct SparseSGDFunctor<platform::CPUPlace, T> { ...@@ -89,11 +89,12 @@ struct SparseSGDFunctor<platform::CPUPlace, T> {
}; };
template struct SparseSGDFunctor<platform::CPUPlace, float>; template struct SparseSGDFunctor<platform::CPUPlace, float>;
template struct SparseSGDFunctor<platform::CPUPlace, double>;
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(sgd, ops::SGDOp, ops::SGDOpMaker); REGISTER_OP_WITHOUT_GRADIENT(sgd, ops::SGDOp, ops::SGDOpMaker);
REGISTER_OP_CPU_KERNEL(sgd, REGISTER_OP_CPU_KERNEL(sgd, ops::SGDOpKernel<paddle::platform::CPUPlace, float>,
ops::SGDOpKernel<paddle::platform::CPUPlace, float>); ops::SGDOpKernel<paddle::platform::CPUPlace, double>);
...@@ -71,10 +71,11 @@ struct SparseSGDFunctor<platform::GPUPlace, T> { ...@@ -71,10 +71,11 @@ struct SparseSGDFunctor<platform::GPUPlace, T> {
}; };
template struct SparseSGDFunctor<platform::GPUPlace, float>; template struct SparseSGDFunctor<platform::GPUPlace, float>;
template struct SparseSGDFunctor<platform::GPUPlace, double>;
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(sgd, REGISTER_OP_GPU_KERNEL(sgd, ops::SGDOpKernel<paddle::platform::GPUPlace, float>,
ops::SGDOpKernel<paddle::platform::GPUPlace, float>); ops::SGDOpKernel<paddle::platform::GPUPlace, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/sign_op.h"
namespace paddle {
namespace operators {
class SignOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SignOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SignOp should not be null.");
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Out");
}
};
template <typename AttrType>
class SignOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SignOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor) Input tensor of sign operator.");
AddOutput("Out", "(Tensor) Output tensor of sign operator.");
AddComment(R"DOC(Sign operator
The equation is: Out = X.sign()
)DOC");
}
};
class SignGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
std::unique_ptr<framework::OpDescBind> Apply() const override {
auto *grad_op = new framework::OpDescBind();
grad_op->SetType("scale");
grad_op->SetInput("X", OutputGrad("Out"));
grad_op->SetOutput("Out", InputGrad("X"));
grad_op->SetAttr("scale", 0.0f);
return std::unique_ptr<framework::OpDescBind>(grad_op);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(sign, ops::SignOp, ops::SignOpMaker<float>,
ops::SignGradMaker);
REGISTER_OP_CPU_KERNEL(sign,
ops::SignKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/sign_op.h"
REGISTER_OP_GPU_KERNEL(
sign, paddle::operators::SignKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
class SignKernel : public framework::OpKernel<T> {
public:
virtual void Compute(const framework::ExecutionContext& context) const {
auto* out = context.Output<framework::Tensor>("Out");
auto* in = context.Input<framework::Tensor>("X");
out->mutable_data<T>(in->place());
auto eigen_out = framework::EigenVector<T>::Flatten(*out);
auto eigen_in = framework::EigenVector<T>::Flatten(*in);
auto& place = context.GetEigenDevice<Place>();
eigen_out.device(place) = eigen_in.sign();
}
};
} // namespace operators
} // namespace paddle
...@@ -35,13 +35,6 @@ class SumKernel : public framework::OpKernel<T> { ...@@ -35,13 +35,6 @@ class SumKernel : public framework::OpKernel<T> {
if (out_var->IsType<framework::LoDTensor>()) { if (out_var->IsType<framework::LoDTensor>()) {
auto* out = context.Output<Tensor>("Out"); auto* out = context.Output<Tensor>("Out");
// Runtime InferShape
for (int i = 0; i < N; i++) {
if (in_vars[i]->IsType<framework::LoDTensor>()) {
out->Resize(in_vars[i]->Get<framework::LoDTensor>().dims());
break;
}
}
out->mutable_data<T>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());
auto result = EigenVector<T>::Flatten(*out); auto result = EigenVector<T>::Flatten(*out);
...@@ -73,12 +66,10 @@ class SumKernel : public framework::OpKernel<T> { ...@@ -73,12 +66,10 @@ class SumKernel : public framework::OpKernel<T> {
first_dim += in_vars[i]->Get<SelectedRows>().rows().size(); first_dim += in_vars[i]->Get<SelectedRows>().rows().size();
} }
auto in_dim = in_vars[0]->Get<SelectedRows>().value().dims(); auto in_dim = in_vars[0]->Get<SelectedRows>().value().dims();
auto in_dim_vec = framework::vectorize(in_dim); auto in_dim_vec = framework::vectorize(in_dim);
in_dim_vec[0] = static_cast<int64_t>(first_dim); in_dim_vec[0] = static_cast<int64_t>(first_dim);
out_value->Resize(framework::make_ddim(in_dim_vec)); out_value->Resize(framework::make_ddim(in_dim_vec));
out_value->mutable_data<T>(context.GetPlace()); out_value->mutable_data<T>(context.GetPlace());
math::SelectedRowsAddTo<Place, T> functor; math::SelectedRowsAddTo<Place, T> functor;
......
...@@ -95,4 +95,5 @@ Used to initialize tensor with uniform random generator. ...@@ -95,4 +95,5 @@ Used to initialize tensor with uniform random generator.
REGISTER_OP_WITHOUT_GRADIENT(uniform_random, paddle::operators::UniformRandomOp, REGISTER_OP_WITHOUT_GRADIENT(uniform_random, paddle::operators::UniformRandomOp,
paddle::operators::UniformRandomOpMaker); paddle::operators::UniformRandomOpMaker);
REGISTER_OP_CPU_KERNEL(uniform_random, REGISTER_OP_CPU_KERNEL(uniform_random,
paddle::operators::CPUUniformRandomKernel<float>); paddle::operators::CPUUniformRandomKernel<float>,
paddle::operators::CPUUniformRandomKernel<double>);
...@@ -64,4 +64,5 @@ class GPUUniformRandomKernel : public framework::OpKernel<T> { ...@@ -64,4 +64,5 @@ class GPUUniformRandomKernel : public framework::OpKernel<T> {
} // namespace paddle } // namespace paddle
REGISTER_OP_GPU_KERNEL(uniform_random, REGISTER_OP_GPU_KERNEL(uniform_random,
paddle::operators::GPUUniformRandomKernel<float>); paddle::operators::GPUUniformRandomKernel<float>,
paddle::operators::GPUUniformRandomKernel<double>);
...@@ -31,9 +31,7 @@ namespace platform { ...@@ -31,9 +31,7 @@ namespace platform {
TEST(NCCL, init) { TEST(NCCL, init) {
std::vector<ncclComm_t> comms; std::vector<ncclComm_t> comms;
comms.resize(dev_count); comms.resize(dev_count);
PADDLE_ENFORCE(dynload::ncclCommInitAll(comms.data(), dev_count, nullptr));
auto status = dynload::ncclCommInitAll(comms.data(), dev_count, nullptr);
PADDLE_ENFORCE(status);
for (int i = 0; i < dev_count; ++i) { for (int i = 0; i < dev_count; ++i) {
dynload::ncclCommDestroy(comms[i]); dynload::ncclCommDestroy(comms[i]);
} }
...@@ -64,8 +62,7 @@ TEST(NCCL, all_reduce) { ...@@ -64,8 +62,7 @@ TEST(NCCL, all_reduce) {
std::vector<ncclComm_t> comms; std::vector<ncclComm_t> comms;
comms.resize(dev_count); comms.resize(dev_count);
VLOG(1) << "Initializing ncclComm"; VLOG(1) << "Initializing ncclComm";
auto status = dynload::ncclCommInitAll(comms.data(), dev_count, nullptr); PADDLE_ENFORCE(dynload::ncclCommInitAll(comms.data(), dev_count, nullptr));
PADDLE_ENFORCE(status);
VLOG(1) << "ncclComm initialized"; VLOG(1) << "ncclComm initialized";
VLOG(1) << "Creating thread data"; VLOG(1) << "Creating thread data";
std::vector<std::unique_ptr<PerThreadData<double>>> data; std::vector<std::unique_ptr<PerThreadData<double>>> data;
......
if(WITH_PYTHON) if(WITH_PYTHON)
cc_library(paddle_pybind SHARED cc_library(paddle_pybind SHARED
SRCS pybind.cc exception.cc protobuf.cc SRCS pybind.cc exception.cc protobuf.cc
DEPS pybind python backward proto_desc tensor_array paddle_memory executor DEPS pybind python backward proto_desc tensor_array paddle_memory executor prune
${GLOB_OP_LIB}) ${GLOB_OP_LIB})
endif(WITH_PYTHON) endif(WITH_PYTHON)
......
...@@ -141,6 +141,13 @@ void BindProgramDesc(py::module &m) { ...@@ -141,6 +141,13 @@ void BindProgramDesc(py::module &m) {
desc->SerializeToString(&res), desc->SerializeToString(&res),
"Serialize ProgramDesc Error. This could be a bug of Paddle."); "Serialize ProgramDesc Error. This could be a bug of Paddle.");
return res; return res;
})
.def("parse_from_string",
[](ProgramDescBind &program_desc, const std::string &data) {
ProgramDesc *desc = program_desc.Proto();
PADDLE_ENFORCE(desc->ParseFromString(data),
"Fail to parse ProgramDesc from string. This could "
"be a bug of Paddle.");
}); });
} }
......
...@@ -19,6 +19,7 @@ limitations under the License. */ ...@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/framework/feed_fetch_method.h" #include "paddle/framework/feed_fetch_method.h"
#include "paddle/framework/framework.pb.h" #include "paddle/framework/framework.pb.h"
#include "paddle/framework/lod_tensor.h" #include "paddle/framework/lod_tensor.h"
#include "paddle/framework/prune.h"
#include "paddle/framework/selected_rows.h" #include "paddle/framework/selected_rows.h"
#include "paddle/framework/tensor_array.h" #include "paddle/framework/tensor_array.h"
#include "paddle/operators/cond_op.h" #include "paddle/operators/cond_op.h"
...@@ -32,6 +33,11 @@ limitations under the License. */ ...@@ -32,6 +33,11 @@ limitations under the License. */
#include "paddle/pybind/tensor_py.h" #include "paddle/pybind/tensor_py.h"
#include "paddle/string/to_string.h" #include "paddle/string/to_string.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/operators/nccl/nccl_gpu_common.h"
#include "paddle/platform/gpu_info.h"
#endif
namespace paddle { namespace paddle {
namespace pybind { namespace pybind {
static size_t UniqueIntegerGenerator() { static size_t UniqueIntegerGenerator() {
...@@ -203,6 +209,13 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -203,6 +209,13 @@ All parameter, weight, gradient are variables in Paddle.
return self.GetMutable<SelectedRows>(); return self.GetMutable<SelectedRows>();
}, },
py::return_value_policy::reference) py::return_value_policy::reference)
#ifdef PADDLE_WITH_CUDA
.def("get_communicator",
[](Variable &self) -> platform::Communicator * {
return self.GetMutable<platform::Communicator>();
},
py::return_value_policy::reference)
#endif
.def("get_net", .def("get_net",
[](Variable &self) -> operators::NetOp * { [](Variable &self) -> operators::NetOp * {
return self.GetMutable<operators::NetOp>(); return self.GetMutable<operators::NetOp>();
...@@ -237,6 +250,16 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -237,6 +250,16 @@ All parameter, weight, gradient are variables in Paddle.
} }
return ret_values; return ret_values;
}); });
m.def("prune", [](const ProgramDescBind &origin,
const std::vector<std::array<size_t, 2>> &targets) {
ProgramDescBind prog_with_targets(origin);
for (const auto &t : targets) {
prog_with_targets.Block(t[0])->Op(t[1])->MarkAsTarget();
}
ProgramDesc pruned_desc;
Prune(*prog_with_targets.Proto(), &pruned_desc);
return new ProgramDescBind(pruned_desc);
});
m.def_submodule( m.def_submodule(
"var_names", "var_names",
"The module will return special predefined variable name in Paddle") "The module will return special predefined variable name in Paddle")
...@@ -258,8 +281,11 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -258,8 +281,11 @@ All parameter, weight, gradient are variables in Paddle.
return new paddle::platform::CUDADeviceContext(place); return new paddle::platform::CUDADeviceContext(place);
#endif #endif
}); });
// clang-format on // clang-format on
#ifdef PADDLE_WITH_CUDA
py::class_<platform::Communicator>(m, "Communicator").def(py::init<>());
#endif
py::class_<platform::GPUPlace>(m, "GPUPlace") py::class_<platform::GPUPlace>(m, "GPUPlace")
.def(py::init<int>()) .def(py::init<int>())
.def("__str__", string::to_string<const platform::GPUPlace &>); .def("__str__", string::to_string<const platform::GPUPlace &>);
...@@ -468,6 +494,9 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -468,6 +494,9 @@ All parameter, weight, gradient are variables in Paddle.
BindOpDesc(m); BindOpDesc(m);
m.def("op_support_gpu", OpSupportGPU); m.def("op_support_gpu", OpSupportGPU);
#ifdef PADDLE_WITH_CUDA
m.def("get_cuda_device_count", platform::GetCUDADeviceCount);
#endif
return m.ptr(); return m.ptr();
} }
......
...@@ -85,7 +85,8 @@ struct CastToPyBufferImpl<true, I, ARGS...> { ...@@ -85,7 +85,8 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
} // namespace details } // namespace details
inline py::buffer_info CastToPyBuffer(framework::Tensor &tensor) { inline py::buffer_info CastToPyBuffer(framework::Tensor &tensor) {
auto buffer_info = auto buffer_info =
details::CastToPyBufferImpl<true, 0, float, int, double>()(tensor); details::CastToPyBufferImpl<true, 0, float, int, double, int64_t>()(
tensor);
return buffer_info; return buffer_info;
} }
......
...@@ -251,6 +251,8 @@ class Operator(object): ...@@ -251,6 +251,8 @@ class Operator(object):
self.desc.set_output(out_proto.name, out_argu_names) self.desc.set_output(out_proto.name, out_argu_names)
if attrs is not None: if attrs is not None:
if not isinstance(attrs, dict):
raise TypeError("'attrs' should be a dict.")
for attr in proto.attrs: for attr in proto.attrs:
attr_name = attr.name attr_name = attr.name
if (not attr_name in attrs) or (attrs[attr_name] is None): if (not attr_name in attrs) or (attrs[attr_name] is None):
...@@ -291,6 +293,14 @@ class Operator(object): ...@@ -291,6 +293,14 @@ class Operator(object):
def output_names(self): def output_names(self):
return self.desc.output_names() return self.desc.output_names()
@property
def idx(self):
for i, op in enumerate(self.block.ops):
if op == self:
return i
raise ValueError(
"Can't find op itself in it's block. It could be a bug of Paddle.")
def has_attr(self, name): def has_attr(self, name):
return self.desc.has_attr(name) return self.desc.has_attr(name)
...@@ -342,7 +352,10 @@ class Block(object): ...@@ -342,7 +352,10 @@ class Block(object):
return {v for k, v in self.vars.iteritems() if isinstance(v, Parameter)} return {v for k, v in self.vars.iteritems() if isinstance(v, Parameter)}
def create_var(self, *args, **kwargs): def create_var(self, *args, **kwargs):
return Variable(self, *args, **kwargs) var = Variable(self, *args, **kwargs)
if 'init_attr' in kwargs:
self._prepend_initialize_ops_(var, kwargs['init_attr'])
return var
def has_var(self, name): def has_var(self, name):
return name in self.vars return name in self.vars
...@@ -440,10 +453,31 @@ class Program(object): ...@@ -440,10 +453,31 @@ class Program(object):
p.sync_with_cpp() p.sync_with_cpp()
return p return p
def prune(self, targets):
if not isinstance(targets, list):
targets = [targets]
targets_idx = []
for t in targets:
if not isinstance(t, Operator):
if isinstance(t, Variable):
t = t.op
else:
raise ValueError(
"All targets of prune() can only be Variable or Operator."
)
targets_idx.append([t.block.idx, t.idx])
res = Program()
res.desc = core.prune(self.desc, targets_idx)
res.blocks = [Block(res, i) for i in xrange(res.desc.num_blocks())]
res.sync_with_cpp()
return res
@staticmethod @staticmethod
def parse_from_string(binary_str): def parse_from_string(binary_str):
p = Program() p = Program()
p.desc = core.ProgramDesc(binary_str) p.desc = core.ProgramDesc(binary_str)
p.blocks = [Block(p, i) for i in xrange(p.desc.num_blocks())]
p.sync_with_cpp() p.sync_with_cpp()
return p return p
......
import os import os
import cPickle as pickle
from paddle.v2.framework.framework import Program, Parameter, g_program, \ from paddle.v2.framework.framework import Program, Parameter, g_program, \
Variable Variable
__all__ = [ __all__ = [
'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params', 'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params',
'load_persistables' 'load_persistables', "save_inference_model", "load_inference_model"
] ]
...@@ -31,7 +32,7 @@ def _clone_var_in_block_(block, var): ...@@ -31,7 +32,7 @@ def _clone_var_in_block_(block, var):
def save_vars(executor, dirname, program=None, vars=None, predicate=None): def save_vars(executor, dirname, program=None, vars=None, predicate=None):
""" """
Save variables to directory by executor. Save variables to directory by executor.
:param executor: executor that save variable :param executor: executor that save variable
:param dirname: directory path :param dirname: directory path
:param program: program. If vars is None, then filter all variables in this :param program: program. If vars is None, then filter all variables in this
...@@ -92,7 +93,7 @@ def save_persistables(executor, dirname, program=None): ...@@ -92,7 +93,7 @@ def save_persistables(executor, dirname, program=None):
def load_vars(executor, dirname, program=None, vars=None, predicate=None): def load_vars(executor, dirname, program=None, vars=None, predicate=None):
""" """
Load variables from directory by executor. Load variables from directory by executor.
:param executor: executor that save variable :param executor: executor that save variable
:param dirname: directory path :param dirname: directory path
:param program: program. If vars is None, then filter all variables in this :param program: program. If vars is None, then filter all variables in this
...@@ -124,6 +125,7 @@ def load_vars(executor, dirname, program=None, vars=None, predicate=None): ...@@ -124,6 +125,7 @@ def load_vars(executor, dirname, program=None, vars=None, predicate=None):
inputs={}, inputs={},
outputs={"Out": [new_var]}, outputs={"Out": [new_var]},
attrs={'file_path': os.path.join(dirname, new_var.name)}) attrs={'file_path': os.path.join(dirname, new_var.name)})
executor.run(load_prog) executor.run(load_prog)
...@@ -141,3 +143,88 @@ def load_persistables(executor, dirname, program=None): ...@@ -141,3 +143,88 @@ def load_persistables(executor, dirname, program=None):
""" """
load_vars( load_vars(
executor, dirname=dirname, program=program, predicate=is_persistable) executor, dirname=dirname, program=program, predicate=is_persistable)
def save_inference_model(dirname,
feeded_var_names,
target_vars,
executor,
program=None):
"""
Build a model especially for inference,
and save it to directory by the executor.
:param dirname: directory path
:param feeded_var_names: Names of variables that need to be feeded data during inference
:param target_vars: Variables from which we can get inference results.
:param executor: executor that save inference model
:param program: original program, which will be pruned to build the inference model.
Default g_program.
:return: None
"""
if program is None:
program = g_program
if not isinstance(target_vars, list):
target_vars = [target_vars]
if not os.path.isdir(dirname):
os.makedirs(dirname)
pruned_program = program.prune(target_vars)
fetch_var_names = [v.name for v in target_vars]
model_file_name = dirname + "/__model__"
with open(model_file_name, "w") as f:
pickle.dump({
"program_desc_str": pruned_program.desc.serialize_to_string(),
"feed_var_names": feeded_var_names,
"fetch_var_names": fetch_var_names
}, f, -1)
save_params(executor, dirname, program)
def load_persistables_if_exist(executor, dirname, program=None):
filenames = next(os.walk(dirname))[2]
filenames = set(filenames)
def _is_presistable_and_exist_(var):
if not is_persistable(var):
return False
else:
return var.name in filenames
load_vars(
executor,
dirname,
program=program,
vars=None,
predicate=_is_presistable_and_exist_)
def load_inference_model(dirname, executor):
"""
Load inference model from a directory
:param dirname: directory path
:param executor: executor that load inference model
:return: [program, feed_var_names, fetch_var_names]
program: program especially for inference.
feeded_var_names: Names of variables that need to feed data
fetch_vars: Variables from which we can get inference results.
"""
if not os.path.isdir(dirname):
raise ValueError("There is no directory named '%s'", dirname)
model_file_name = dirname + "/__model__"
model = pickle.load(open(model_file_name, "r"))
program_desc_str = model["program_desc_str"]
feed_var_names = model["feed_var_names"]
fetch_var_names = model["fetch_var_names"]
program = Program.parse_from_string(program_desc_str)
load_persistables_if_exist(executor, dirname, program)
fetch_vars = [program.global_block().var(name) for name in fetch_var_names]
return [program, feed_var_names, fetch_vars]
...@@ -131,12 +131,14 @@ class LayerHelper(object): ...@@ -131,12 +131,14 @@ class LayerHelper(object):
return dtype return dtype
def create_parameter(self, attr, shape, dtype, suffix='w'): def create_parameter(self, attr, shape, dtype, suffix='w'):
if attr['name'] is None: # Deepcopy the attr so that parameters can be shared in program
attr['name'] = unique_name(".".join([self.name, suffix])) attr_copy = copy.deepcopy(attr)
if attr_copy['name'] is None:
attr_copy['name'] = unique_name(".".join([self.name, suffix]))
self.init_program.global_block().create_parameter( self.init_program.global_block().create_parameter(
dtype=dtype, shape=shape, **attr) dtype=dtype, shape=shape, **attr_copy)
return self.program.global_block().create_parameter( return self.program.global_block().create_parameter(
name=attr['name'], dtype=dtype, shape=shape) name=attr_copy['name'], dtype=dtype, shape=shape)
def create_tmp_variable(self, dtype): def create_tmp_variable(self, dtype):
return self.program.current_block().create_var( return self.program.current_block().create_var(
......
...@@ -5,7 +5,7 @@ import re ...@@ -5,7 +5,7 @@ import re
__all__ = [ __all__ = [
'fc', 'data', 'cross_entropy', 'conv2d', 'pool2d', 'embedding', 'concat', 'fc', 'data', 'cross_entropy', 'conv2d', 'pool2d', 'embedding', 'concat',
'StaticRNN' 'StaticRNN', 'cast'
] ]
...@@ -61,6 +61,7 @@ def fc(input, ...@@ -61,6 +61,7 @@ def fc(input,
def embedding(input, def embedding(input,
size, size,
data_type='float32', data_type='float32',
is_sparse=False,
param_attr=None, param_attr=None,
program=None, program=None,
init_program=None): init_program=None):
...@@ -72,7 +73,8 @@ def embedding(input, ...@@ -72,7 +73,8 @@ def embedding(input,
type='lookup_table', type='lookup_table',
inputs={'Ids': input, inputs={'Ids': input,
'W': w}, 'W': w},
outputs={'Out': tmp}) outputs={'Out': tmp},
attrs={'is_sparse': is_sparse})
return tmp return tmp
...@@ -159,6 +161,19 @@ def _create_op_func_(op_type): ...@@ -159,6 +161,19 @@ def _create_op_func_(op_type):
_create_op_func_('mean') _create_op_func_('mean')
_create_op_func_('mul') _create_op_func_('mul')
_create_op_func_('dropout') _create_op_func_('dropout')
_create_op_func_('reshape')
def cast(x, data_type, program=None):
helper = LayerHelper('cast', **locals())
out = helper.create_tmp_variable(dtype=data_type)
helper.append_op(
type='cast',
inputs={'X': [x]},
outputs={'Out': [out]},
attrs={'in_data_type': x.data_type,
'out_data_type': out.data_type})
return out
def concat(input, axis, program=None, init_program=None): def concat(input, axis, program=None, init_program=None):
...@@ -294,6 +309,96 @@ def pool2d(input, ...@@ -294,6 +309,96 @@ def pool2d(input,
return pool_out return pool_out
def batch_norm(input,
act=None,
is_test=False,
momentum=0.9,
epsilon=1e05,
param_attr=None,
bias_attr=None,
data_layout='NCHW',
program=None,
init_program=None):
helper = LayerHelper('batch_norm', **locals())
dtype = helper.input_dtype()
input_shape = input.shape
if data_layout == 'NCHW':
channel_num = input_shape[1]
else:
if data_layout == 'NHWC':
channel_num = input_shape[-1]
else:
raise ValueError("unsupported data layout:" + data_layout)
def get_init_attr(value):
if not isinstance(value, float):
raise ValueError("attr value should be a float")
return {'type': 'fill_constant', 'value': value}
def prepend_init_op(var, init_attr):
assert isinstance(var, Variable)
op_type = init_attr['type']
init_attr['shape'] = var.shape
init_attr['data_type'] = int(var.data_type)
op = var.block.prepend_op(
type=op_type, inputs=None, outputs={'Out': [var]}, attrs=init_attr)
return op
def create_persistable_var(dtype, shape, init_attr=None):
name = unique_name(".".join([helper.name, "xxxx"]))
var = init_program.global_block().create_var(
dtype=dtype, shape=shape, name=name, persistable=True)
if 'init_attr' is not None:
prepend_init_op(var, init_attr)
return program.global_block().create_var(
name=name, dtype=dtype, shape=shape, persistable=True)
param_shape = [channel_num]
# create parameter
scale = helper.create_parameter(
attr=helper.param_attr, shape=param_shape, dtype=dtype)
bias = helper.create_parameter(
attr=helper.param_attr, shape=param_shape, dtype=dtype)
# create input
mean = create_persistable_var(dtype, param_shape, get_init_attr(0.0))
variance = create_persistable_var(dtype, param_shape, get_init_attr(1.0))
# create output
# mean and mean_out share the same memory
mean_out = mean
# variance and variance out share the same memory
variance_out = variance
saved_mean = helper.create_tmp_variable(dtype)
saved_variance = helper.create_tmp_variable(dtype)
batch_norm_out = helper.create_tmp_variable(dtype)
helper.append_op(
type="batch_norm",
inputs={
"X": input,
"Scale": scale,
"Bias": bias,
"Mean": mean,
"Variance": variance
},
outputs={
"Y": batch_norm_out,
"MeanOut": mean_out,
"VarianceOut": variance_out,
"SavedMean": saved_mean,
"SavedVariance": saved_variance
},
attrs={"momentum": momentum,
"epsilon": epsilon,
"is_test": is_test})
return helper.append_activation(batch_norm_out)
class BlockGuard(object): class BlockGuard(object):
""" """
BlockGuard used to create sub-block in program by using Python `with` BlockGuard used to create sub-block in program by using Python `with`
......
...@@ -7,6 +7,7 @@ def simple_img_conv_pool(input, ...@@ -7,6 +7,7 @@ def simple_img_conv_pool(input,
pool_size, pool_size,
pool_stride, pool_stride,
act, act,
pool_type='max',
program=None, program=None,
init_program=None): init_program=None):
conv_out = layers.conv2d( conv_out = layers.conv2d(
...@@ -20,7 +21,75 @@ def simple_img_conv_pool(input, ...@@ -20,7 +21,75 @@ def simple_img_conv_pool(input,
pool_out = layers.pool2d( pool_out = layers.pool2d(
input=conv_out, input=conv_out,
pool_size=pool_size, pool_size=pool_size,
pool_type='max', pool_type=pool_type,
pool_stride=pool_stride,
program=program,
init_program=init_program)
return pool_out
def img_conv_group(input,
conv_num_filter,
pool_size,
conv_padding=1,
conv_filter_size=3,
conv_act=None,
conv_with_batchnorm=False,
conv_batchnorm_drop_rate=None,
pool_stride=1,
pool_type=None,
program=None,
init_program=None):
"""
Image Convolution Group, Used for vgg net.
"""
tmp = input
assert isinstance(conv_num_filter, list) or \
isinstance(conv_num_filter, tuple)
def __extend_list__(obj):
if not hasattr(obj, '__len__'):
return [obj] * len(conv_num_filter)
else:
return obj
conv_padding = __extend_list__(conv_padding)
conv_filter_size = __extend_list__(conv_filter_size)
conv_with_batchnorm = __extend_list__(conv_with_batchnorm)
conv_batchnorm_drop_rate = __extend_list__(conv_batchnorm_drop_rate)
for i in xrange(len(conv_num_filter)):
local_conv_act = conv_act
if conv_with_batchnorm[i]:
local_conv_act = None
tmp = layers.conv2d(
input=tmp,
num_filters=conv_num_filter[i],
filter_size=conv_filter_size[i],
padding=conv_padding[i],
act=local_conv_act,
program=program,
init_program=init_program)
if conv_with_batchnorm[i]:
tmp = layers.batch_norm(
input=tmp,
act=conv_act,
program=program,
init_program=init_program)
drop_rate = conv_batchnorm_drop_rate[i]
if abs(drop_rate) > 1e-5:
tmp = layers.dropout(
x=tmp,
dropout_prob=drop_rate,
program=program,
init_program=init_program)
pool_out = layers.pool2d(
input=tmp,
pool_size=pool_size,
pool_type=pool_type,
pool_stride=pool_stride, pool_stride=pool_stride,
program=program, program=program,
init_program=init_program) init_program=init_program)
......
...@@ -18,7 +18,8 @@ class Optimizer(object): ...@@ -18,7 +18,8 @@ class Optimizer(object):
but need to use one of it's implementation. but need to use one of it's implementation.
""" """
def __init__(self): def __init__(self, global_step=None):
self._global_step = global_step
# Dictionary of accumulators. Some optimizer subclasses need to # Dictionary of accumulators. Some optimizer subclasses need to
# allocate and manage extra variables associated with the parameters # allocate and manage extra variables associated with the parameters
# to train. These variables are called accumulators. # to train. These variables are called accumulators.
...@@ -109,6 +110,26 @@ class Optimizer(object): ...@@ -109,6 +110,26 @@ class Optimizer(object):
format(name, param.name)) format(name, param.name))
return self._accumulators[name][param.name] return self._accumulators[name][param.name]
def _increment_global_step(self, block):
"""Increment the global step by 1 after every iteration
Args:
block: the block in which the loss variable is present
Returns:
list with global_step increment op as its only element
"""
assert isinstance(block, framework.Block)
assert self._global_step is not None
# create the increment op
increment_op = block.append_op(
type="increment",
inputs={"X": self._global_step},
outputs={"Out": self._global_step},
attrs={"step": 1.0})
return increment_op
def create_optimization_pass(self, parameters_and_grads, loss): def create_optimization_pass(self, parameters_and_grads, loss):
"""Add optimization operators to update gradients to variables. """Add optimization operators to update gradients to variables.
...@@ -152,6 +173,8 @@ class Optimizer(object): ...@@ -152,6 +173,8 @@ class Optimizer(object):
if finish_ops is not None: if finish_ops is not None:
return_ops += finish_ops return_ops += finish_ops
if self._global_step is not None:
return_ops.append(self._increment_global_step(loss.block))
return return_ops return return_ops
def minimize(self, loss, parameter_list=None, no_grad_set=None): def minimize(self, loss, parameter_list=None, no_grad_set=None):
...@@ -172,9 +195,9 @@ class SGDOptimizer(Optimizer): ...@@ -172,9 +195,9 @@ class SGDOptimizer(Optimizer):
""" Simple SGD optimizer without any state. """ Simple SGD optimizer without any state.
""" """
def __init__(self, learning_rate): def __init__(self, learning_rate, global_step=None):
assert learning_rate is not None assert learning_rate is not None
super(SGDOptimizer, self).__init__() super(SGDOptimizer, self).__init__(global_step)
self.type = "sgd" self.type = "sgd"
self._learning_rate = learning_rate self._learning_rate = learning_rate
...@@ -215,10 +238,14 @@ class MomentumOptimizer(Optimizer): ...@@ -215,10 +238,14 @@ class MomentumOptimizer(Optimizer):
""" """
_velocity_acc_str = "velocity" _velocity_acc_str = "velocity"
def __init__(self, learning_rate, momentum, use_nesterov=False): def __init__(self,
learning_rate,
momentum,
use_nesterov=False,
global_step=None):
assert learning_rate is not None assert learning_rate is not None
assert momentum is not None assert momentum is not None
super(MomentumOptimizer, self).__init__() super(MomentumOptimizer, self).__init__(global_step)
self.type = "momentum" self.type = "momentum"
self._learning_rate = learning_rate self._learning_rate = learning_rate
self._momentum = momentum self._momentum = momentum
...@@ -275,10 +302,10 @@ class AdagradOptimizer(Optimizer): ...@@ -275,10 +302,10 @@ class AdagradOptimizer(Optimizer):
""" """
_moment_acc_str = "moment" _moment_acc_str = "moment"
def __init__(self, learning_rate, epsilon=1.0e-6): def __init__(self, learning_rate, epsilon=1.0e-6, global_step=None):
assert learning_rate is not None assert learning_rate is not None
assert epsilon is not None assert epsilon is not None
super(AdagradOptimizer, self).__init__() super(AdagradOptimizer, self).__init__(global_step)
self.type = "adagrad" self.type = "adagrad"
self._learning_rate = learning_rate self._learning_rate = learning_rate
self._epsilon = epsilon self._epsilon = epsilon
...@@ -337,12 +364,13 @@ class AdamOptimizer(Optimizer): ...@@ -337,12 +364,13 @@ class AdamOptimizer(Optimizer):
learning_rate=0.001, learning_rate=0.001,
beta1=0.9, beta1=0.9,
beta2=0.999, beta2=0.999,
epsilon=1e-8): epsilon=1e-8,
global_step=None):
assert learning_rate is not None assert learning_rate is not None
assert beta1 is not None assert beta1 is not None
assert beta2 is not None assert beta2 is not None
assert epsilon is not None assert epsilon is not None
super(AdamOptimizer, self).__init__() super(AdamOptimizer, self).__init__(global_step)
self.type = "adam" self.type = "adam"
self._learning_rate = learning_rate self._learning_rate = learning_rate
self._beta1 = beta1 self._beta1 = beta1
...@@ -458,7 +486,8 @@ class AdamaxOptimizer(Optimizer): ...@@ -458,7 +486,8 @@ class AdamaxOptimizer(Optimizer):
learning_rate=0.001, learning_rate=0.001,
beta1=0.9, beta1=0.9,
beta2=0.999, beta2=0.999,
epsilon=1e-8): epsilon=1e-8,
global_step=None):
assert learning_rate is not None assert learning_rate is not None
assert beta1 is not None assert beta1 is not None
assert beta2 is not None assert beta2 is not None
......
import paddle.v2.framework.framework as framework import paddle.v2.framework.framework as framework
__all__ = ['append_regularization_ops', 'L2DecayRegularizer'] __all__ = [
'append_regularization_ops', 'L2DecayRegularizer', 'L1DecayRegularizer'
]
def append_regularization_ops(parameters_and_grads): def append_regularization_ops(parameters_and_grads):
...@@ -97,3 +99,43 @@ class L2DecayRegularizer(WeightDecayRegularizer): ...@@ -97,3 +99,43 @@ class L2DecayRegularizer(WeightDecayRegularizer):
attrs={"scale": self._regularization_coeff}) attrs={"scale": self._regularization_coeff})
return decay return decay
class L1DecayRegularizer(WeightDecayRegularizer):
"""Implements the L1 Weight Decay Regularization
"""
def __init__(self, regularization_coeff=0.0):
assert regularization_coeff is not None
super(L1DecayRegularizer, self).__init__()
self._regularization_coeff = regularization_coeff
def __call__(self, param, block):
"""Add L1 weight decay ops to network
Adds L1 weight decay ops.
L1WeightDecay = reg_coeff * sign(parameter)
Args:
param: parameter variable for which regularization is applied
block: block in which variable is to be created
Returns:
new variable for weight decay
"""
assert isinstance(param, framework.Parameter)
assert isinstance(block, framework.Block)
decay = block.create_var(
dtype="float32", shape=param.shape, lod_level=param.lod_level)
# Append sign op
block.append_op(
type='sign', inputs={"X": param}, outputs={"Out": decay})
# Append scale op to the output of sign op
block.append_op(
type='scale',
inputs={"X": decay},
outputs={"Out": decay},
attrs={"scale": self._regularization_coeff})
return decay
...@@ -21,16 +21,36 @@ def get_backward_op(scope, op, no_grad_set): ...@@ -21,16 +21,36 @@ def get_backward_op(scope, op, no_grad_set):
def _reference_training(x, scale, offset, epsilon, data_format): def _reference_training(x, scale, offset, epsilon, data_format):
if data_format != "NHWC": if data_format == "NCHW":
raise ValueError("data_format must be NHWC, got %s." % data_format) n, c, h, w = x.shape
x_square = x * x x_square = x * x
x_square_sum = np.sum(x_square, (0, 1, 2)) x_square_sum = np.sum(x_square, (0, 2, 3))
x_sum = np.sum(x, axis=(0, 1, 2)) x_sum = np.sum(x, axis=(0, 2, 3))
element_count = np.size(x) / int(np.shape(x)[-1]) element_count = np.size(x) / int(np.shape(x)[1])
mean = x_sum / element_count mean = x_sum / element_count
var = x_square_sum / element_count - mean * mean var = x_square_sum / element_count - mean * mean
normalized = (x - mean) / np.sqrt(var + epsilon) mean_tile = np.reshape(mean, (1, c, 1, 1))
return (normalized * scale + offset), mean, var mean_tile = np.tile(mean_tile, (n, 1, h, w))
var_tile = np.reshape(var, (1, c, 1, 1))
var_tile = np.tile(var_tile, (n, 1, h, w))
normalized = (x - mean_tile) / np.sqrt(var_tile + epsilon)
scale_tile = np.reshape(scale, (1, c, 1, 1))
scale_tile = np.tile(scale_tile, (n, 1, h, w))
offset_tile = np.reshape(offset, (1, c, 1, 1))
offset_tile = np.reshape(offset_tile, (1, c, 1, 1))
y = normalized * scale_tile + offset_tile
return y, mean, var
elif data_format == "NHWC":
x_square = x * x
x_square_sum = np.sum(x_square, (0, 1, 2))
x_sum = np.sum(x, axis=(0, 1, 2))
element_count = np.size(x) / int(np.shape(x)[-1])
mean = x_sum / element_count
var = x_square_sum / element_count - mean * mean
normalized = (x - mean) / np.sqrt(var + epsilon)
return (normalized * scale + offset), mean, var
else:
raise ValueError("Unknown data order.")
def _reference_grad(x, grad_y, scale, mean, var, epsilon, data_format): def _reference_grad(x, grad_y, scale, mean, var, epsilon, data_format):
...@@ -43,8 +63,13 @@ def _reference_grad(x, grad_y, scale, mean, var, epsilon, data_format): ...@@ -43,8 +63,13 @@ def _reference_grad(x, grad_y, scale, mean, var, epsilon, data_format):
# grad_x = # grad_x =
# 1/N * scale * rsqrt(var + epsilon) * (N * grad_y - sum(grad_y) - # 1/N * scale * rsqrt(var + epsilon) * (N * grad_y - sum(grad_y) -
# (x - mean) * sum(grad_y * (x - mean)) / (var + epsilon)) # (x - mean) * sum(grad_y * (x - mean)) / (var + epsilon))
if data_format != "NHWC":
raise ValueError("data_format must be NHWC, got %s." % data_format) # transfer from (N, C, H, W) to (N, H, W, C) to simplify computation
if data_format == "NCHW":
x = np.transpose(x, (0, 2, 3, 1))
grad_y = np.transpose(grad_y, (0, 2, 3, 1))
# raise ValueError("data_format must be NHWC, got %s." % data_format)
grad_x = scale * (grad_y - np.mean( grad_x = scale * (grad_y - np.mean(
grad_y, axis=(0, 1, 2)) - (x - mean) * np.mean( grad_y, axis=(0, 1, 2)) - (x - mean) * np.mean(
grad_y * (x - mean), axis=(0, 1, 2)) / grad_y * (x - mean), axis=(0, 1, 2)) /
...@@ -52,6 +77,12 @@ def _reference_grad(x, grad_y, scale, mean, var, epsilon, data_format): ...@@ -52,6 +77,12 @@ def _reference_grad(x, grad_y, scale, mean, var, epsilon, data_format):
grad_scale = np.sum(grad_y * (x - mean) / np.sqrt(var + epsilon), grad_scale = np.sum(grad_y * (x - mean) / np.sqrt(var + epsilon),
axis=(0, 1, 2)) axis=(0, 1, 2))
grad_offset = np.sum(grad_y, axis=(0, 1, 2)) grad_offset = np.sum(grad_y, axis=(0, 1, 2))
# transfer back to N, C, H, W
if data_format == "NCHW":
grad_x = np.transpose(grad_x, (0, 3, 1, 2))
x = np.transpose(x, (0, 3, 1, 2))
grad_y = np.transpose(grad_y, (0, 3, 1, 2))
return grad_x, grad_scale, grad_offset return grad_x, grad_scale, grad_offset
...@@ -65,61 +96,135 @@ def create_or_get_tensor(scope, var_name, var, place): ...@@ -65,61 +96,135 @@ def create_or_get_tensor(scope, var_name, var, place):
return tensor return tensor
def set_output_grad(scope, outputs, place): def set_output_grad(scope, outputs, place, feed_dict=None):
def __set_tensor__(name): def __set_tensor__(name, data=None):
out_tensor = scope.find_var(name).get_tensor() out_tensor = scope.find_var(name).get_tensor()
grad_tensor = scope.var(grad_var_name(name)).get_tensor() grad_tensor = scope.var(grad_var_name(name)).get_tensor()
out_dtype = out_tensor.dtype() out_dtype = out_tensor.dtype()
if out_dtype == core.DataType.FP64: if data is None:
data = np.ones(out_tensor.shape(), dtype=np.float64) if out_dtype == core.DataType.FP64:
elif out_dtype == core.DataType.FP32: data = np.ones(out_tensor.shape(), dtype=np.float64)
data = np.ones(out_tensor.shape(), dtype=np.float32) elif out_dtype == core.DataType.FP32:
else: data = np.ones(out_tensor.shape(), dtype=np.float32)
raise ValueError("Not supported data type " + str(out_dtype)) else:
raise ValueError("Not supported data type " + str(out_dtype))
grad_tensor.set(data, place) grad_tensor.set(data, place)
for output in outputs: for output in outputs:
__set_tensor__(output) data = None
if output in feed_dict:
data = feed_dict[output]
__set_tensor__(output, data)
class TestBatchNormOp(OpTest): class TestBatchNormOp(OpTest):
def __assert_close(self, tensor, np_array, msg, atol=1e-4): def __assert_close(self, tensor, np_array, msg, atol=1e-4):
self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg) self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)
def test_forward_backward(self): def test_python(self):
# attr
data_format = "NHWC" data_format = "NHWC"
epsilon = 0.00001 epsilon = 0.00001
momentum = 0.9 momentum = 0.9
channel_num = 2 # N, H, W, C: 2, 3, 4, 2
x_shape = [2, 3, 4, channel_num] n, h, w, c = 2, 3, 4, 2
scale_shape = [channel_num] x_shape = [n, h, w, c]
scale_shape = [c]
# input
x_val = np.random.random_sample(x_shape).astype(np.float32) x_val = np.random.random_sample(x_shape).astype(np.float32)
scale_val = np.random.random_sample(scale_shape).astype(np.float32) scale_val = np.random.random_sample(scale_shape).astype(np.float32)
bias_val = np.random.random_sample(scale_shape).astype(np.float32) bias_val = np.random.random_sample(scale_shape).astype(np.float32)
mean = np.zeros(scale_shape).astype(np.float32) mean = np.zeros(scale_shape).astype(np.float32)
variance = np.zeros(scale_shape).astype(np.float32) variance = np.ones(scale_shape).astype(np.float32)
# run forward # run forward
y_out, saved_mean, var_ref = _reference_training( y_out, saved_mean, var_ref = _reference_training(
x_val, scale_val, bias_val, epsilon, data_format) x_val, scale_val, bias_val, epsilon, "NHWC")
#
mean_out = saved_mean * (1. - momentum) + momentum * mean
variance_out = var_ref * (1. - momentum) + momentum * variance
saved_variance = 1. / np.sqrt(var_ref + epsilon)
# running N, C, H, W case
# should produce the same results
x_shape2 = [n, c, h, w]
x_val2 = np.transpose(x_val, (0, 3, 1, 2))
y_out2, saved_mean2, var_ref2 = _reference_training(
x_val2, scale_val, bias_val, epsilon, "NCHW")
self.__assert_close(saved_mean, saved_mean2, "batch mean")
self.__assert_close(var_ref, var_ref2, "batch variance")
# transfer (N, C, H, W) back to (N, H, W, C)
y_out2_trans = np.transpose(y_out2, (0, 2, 3, 1))
self.__assert_close(y_out, y_out2_trans, "batch variance")
print 'python: NHWC, NCHW, forward checking passed'
# test backward now
# NHWC
self.y_grad = np.random.random_sample(x_shape).astype(np.float32)
y_grad = self.y_grad
# y_grad = np.ones(x_shape).astype(np.float32)
x_grad_ref, scale_grad_ref, bias_grad_ref = _reference_grad(
x_val, y_grad, scale_val, saved_mean, var_ref, epsilon, "NHWC")
# run backward # NCHW
mean_out = saved_mean * (1 - momentum) y_grad2 = np.transpose(y_grad, (0, 3, 1, 2))
variance_out = var_ref * (1 - momentum) # y_grad2 = np.ones(x_shape2).astype(np.float32)
saved_variance = 1 / np.sqrt(var_ref + epsilon) x_grad_ref2, scale_grad_ref2, bias_grad_ref2 = _reference_grad(
x_val2, y_grad2, scale_val, saved_mean2, var_ref2, epsilon, "NCHW")
# for gradient test self.__assert_close(scale_grad_ref, scale_grad_ref2, "scale gradient")
y_grad = np.ones(x_shape).astype(np.float32) self.__assert_close(bias_grad_ref, bias_grad_ref2, "bias gradient")
x_grad_ref, scale_grad_ref, bias_grad_ref = _reference_grad(
x_val, y_grad, scale_val, saved_mean, var_ref, epsilon, data_format) x_grad_transpose = np.transpose(x_grad_ref2, (0, 2, 3, 1))
self.__assert_close(x_grad_ref, x_grad_transpose, "x gradient")
print 'python: NHWC, NCHW, backward checking passed'
def test_forward_backward(self):
def test_with_place(place, tensor_format):
# attr
epsilon = 0.00001
momentum = 0.9
# N, H, W, C: 12, 3, 4, 2
n, h, w, c = 2, 3, 4, 2
if data_format == "NHWC":
x_shape = [n, h, w, c]
elif data_format == "NCHW":
x_shape = [n, c, h, w]
else:
raise ValueError("Unknown data type.")
scale_shape = [c]
x_val = np.random.random_sample(x_shape).astype(np.float32)
scale_val = np.random.random_sample(scale_shape).astype(np.float32)
bias_val = np.random.random_sample(scale_shape).astype(np.float32)
mean = np.zeros(scale_shape).astype(np.float32)
variance = np.ones(scale_shape).astype(np.float32)
# run forward
y_out, saved_mean, var_ref = _reference_training(
x_val, scale_val, bias_val, epsilon, data_format)
# update moving mean and variance
mean_out = saved_mean * (1. - momentum) + momentum * mean
variance_out = var_ref * (1. - momentum) + momentum * variance
saved_variance = 1. / np.sqrt(var_ref + epsilon)
# for gradient test
# y_grad = np.ones(x_shape).astype(np.float32)
y_grad = np.zeros(x_shape).astype(np.float32)
y_grad[0, 0, 0, 0] = 1.
# y_grad = np.random.random_sample(x_shape).astype(np.float32)
x_grad_ref, scale_grad_ref, bias_grad_ref = _reference_grad(
x_val, y_grad, scale_val, saved_mean, var_ref, epsilon,
data_format)
def test_with_place(place):
scope = core.Scope() scope = core.Scope()
# create input # create input
...@@ -157,7 +262,7 @@ class TestBatchNormOp(OpTest): ...@@ -157,7 +262,7 @@ class TestBatchNormOp(OpTest):
SavedVariance="saved_variance", SavedVariance="saved_variance",
# attrs # attrs
is_test=False, is_test=False,
tensor_format=data_format, tensor_format=tensor_format,
momentum=momentum, momentum=momentum,
epsilon=epsilon) epsilon=epsilon)
...@@ -170,20 +275,21 @@ class TestBatchNormOp(OpTest): ...@@ -170,20 +275,21 @@ class TestBatchNormOp(OpTest):
self.__assert_close(saved_variance_tensor, saved_variance, self.__assert_close(saved_variance_tensor, saved_variance,
"saved_variance") "saved_variance")
self.__assert_close(mean_out_tensor, mean_out, "mean_out") self.__assert_close(mean_out_tensor, mean_out, "mean_out")
# FIXME(qiao) figure out why with cuDNN variance_out have a higher error rate
if isinstance(place, core.GPUPlace): if isinstance(place, core.GPUPlace):
atol = 5e-2 atol = 5e-2
else: else:
atol = 1e-4 atol = 1e-4
self.__assert_close(variance_out_tensor, variance_out, self.__assert_close(variance_out_tensor, variance_out,
"variance_out", atol) "variance_out", atol)
print "op test forward passed: ", str(place), tensor_format
# run backward # run backward
batch_norm_op_grad = get_backward_op(scope, batch_norm_op, set()) batch_norm_op_grad = get_backward_op(scope, batch_norm_op, set())
set_output_grad( set_output_grad(
scope, scope,
["y_out", "mean", "variance", "saved_mean", "saved_variance"], ["y_out", "mean", "variance", "saved_mean", "saved_variance"],
place) place,
feed_dict={"y_out": y_grad})
batch_norm_op_grad.run(scope, ctx) batch_norm_op_grad.run(scope, ctx)
x_grad_tensor = create_or_get_tensor(scope, x_grad_tensor = create_or_get_tensor(scope,
...@@ -200,12 +306,14 @@ class TestBatchNormOp(OpTest): ...@@ -200,12 +306,14 @@ class TestBatchNormOp(OpTest):
self.__assert_close(x_grad_tensor, x_grad_ref, "x_grad") self.__assert_close(x_grad_tensor, x_grad_ref, "x_grad")
self.__assert_close(scale_grad_tensor, scale_grad_ref, "scale_grad") self.__assert_close(scale_grad_tensor, scale_grad_ref, "scale_grad")
self.__assert_close(bias_grad_tensor, bias_grad_ref, "bias_grad") self.__assert_close(bias_grad_tensor, bias_grad_ref, "bias_grad")
print "op test backward passed: ", str(place), tensor_format
places = [core.CPUPlace()] places = [core.CPUPlace()]
if core.is_compile_gpu() and core.op_support_gpu("batch_norm"): if core.is_compile_gpu() and core.op_support_gpu("batch_norm"):
places.append(core.GPUPlace(0)) places.append(core.GPUPlace(0))
for place in places: for place in places:
test_with_place(place) for data_format in ["NCHW", "NHWC"]:
test_with_place(place, data_format)
if __name__ == '__main__': if __name__ == '__main__':
......
import op_test
import unittest
import numpy as np
import paddle.v2.framework.core as core
class TestCastOp(op_test.OpTest):
def setUp(self):
ipt = np.random.random(size=[10, 10])
self.inputs = {'X': ipt.astype('float32')}
self.outputs = {'Out': ipt.astype('float64')}
self.attrs = {
'in_data_type': int(core.DataType.FP32),
'out_data_type': int(core.DataType.FP64)
}
self.op_type = 'cast'
def test_check_output(self):
self.check_output()
def test_grad(self):
self.check_grad(['X'], ['Out'])
if __name__ == '__main__':
unittest.main()
...@@ -14,7 +14,7 @@ class TestCrossEntropyOp1(OpTest): ...@@ -14,7 +14,7 @@ class TestCrossEntropyOp1(OpTest):
X = randomize_probability(batch_size, class_num, dtype='float64') X = randomize_probability(batch_size, class_num, dtype='float64')
label = np.random.randint(0, class_num, (batch_size, 1), dtype="int32") label = np.random.randint(0, class_num, (batch_size, 1), dtype="int64")
cross_entropy = np.asmatrix( cross_entropy = np.asmatrix(
[[-np.log(X[i][label[i][0]])] for i in range(X.shape[0])], [[-np.log(X[i][label[i][0]])] for i in range(X.shape[0])],
dtype="float64") dtype="float64")
...@@ -92,5 +92,4 @@ class TestCrossEntropyOp3(OpTest): ...@@ -92,5 +92,4 @@ class TestCrossEntropyOp3(OpTest):
if __name__ == "__main__": if __name__ == "__main__":
exit(0) # Gradient operator has bug!
unittest.main() unittest.main()
import unittest
import paddle.v2.framework.layers as layers
import paddle.v2.framework.nets as nets
from paddle.v2.framework.framework import Program
def conv_block(input,
num_filter,
groups,
dropouts,
program=None,
init_program=None):
return nets.img_conv_group(
input=input,
pool_size=2,
pool_stride=2,
conv_num_filter=[num_filter] * groups,
conv_filter_size=3,
conv_act='relu',
conv_with_batchnorm=True,
conv_batchnorm_drop_rate=dropouts,
pool_type='max',
program=program,
init_program=init_program)
class TestLayer(unittest.TestCase):
def test_batch_norm_layer(self):
program = Program()
init_program = Program()
images = layers.data(
name='pixel',
shape=[3, 48, 48],
data_type='float32',
program=program)
layers.batch_norm(
input=images, program=program, init_program=init_program)
#print str(program)
def test_dropout_layer(self):
program = Program()
init_program = Program()
images = layers.data(
name='pixel',
shape=[3, 48, 48],
data_type='float32',
program=program)
layers.dropout(
x=images,
dropout_prob=0.5,
program=program,
init_program=init_program)
#print str(program)
def test_img_conv_group(self):
program = Program()
init_program = Program()
images = layers.data(
name='pixel',
shape=[3, 48, 48],
data_type='float32',
program=program,
init_program=init_program)
conv1 = conv_block(images, 64, 2, [0.3, 0], program, init_program)
conv2 = conv_block(conv1, 256, 3, [0.4, 0.4, 0], program, init_program)
# print str(program)
if __name__ == '__main__':
unittest.main()
import paddle.v2 as paddle
import paddle.v2.framework.layers as layers
import paddle.v2.framework.nets as nets
import paddle.v2.framework.core as core
import paddle.v2.framework.optimizer as optimizer
from paddle.v2.framework.framework import Program, g_program
from paddle.v2.framework.executor import Executor
import numpy as np
def vgg16_bn_drop(input, program, init_program):
def conv_block(input,
num_filter,
groups,
dropouts,
program=None,
init_program=None):
return nets.img_conv_group(
input=input,
pool_size=2,
pool_stride=2,
conv_num_filter=[num_filter] * groups,
conv_filter_size=3,
conv_act='relu',
conv_with_batchnorm=True,
conv_batchnorm_drop_rate=dropouts,
pool_type='max',
program=program,
init_program=init_program)
conv1 = conv_block(input, 64, 2, [0.3, 0], program, init_program)
conv2 = conv_block(conv1, 128, 2, [0.4, 0], program, init_program)
conv3 = conv_block(conv2, 256, 3, [0.4, 0.4, 0], program, init_program)
conv4 = conv_block(conv3, 512, 3, [0.4, 0.4, 0], program, init_program)
conv5 = conv_block(conv4, 512, 3, [0.4, 0.4, 0], program, init_program)
drop = layers.dropout(
x=conv5, dropout_prob=0.5, program=program, init_program=init_program)
fc1 = layers.fc(input=drop,
size=512,
act=None,
program=program,
init_program=init_program)
reshape1 = layers.reshape(
x=fc1,
shape=list(fc1.shape + (1, 1)),
program=program,
init_program=init_program)
bn = layers.batch_norm(
input=reshape1, act='relu', program=program, init_program=init_program)
drop2 = layers.dropout(
x=bn, dropout_prob=0.5, program=program, init_program=init_program)
fc2 = layers.fc(input=drop2,
size=512,
act=None,
program=program,
init_program=init_program)
return fc2
init_program = Program()
program = Program()
classdim = 10
data_shape = [3, 32, 32]
images = layers.data(
name='pixel', shape=data_shape, data_type='float32', program=program)
label = layers.data(
name='label',
shape=[1],
data_type='int64',
program=program,
init_program=init_program)
vgg_net = vgg16_bn_drop(images, program, init_program)
predict = layers.fc(input=vgg_net,
size=classdim,
act='softmax',
program=program,
init_program=init_program)
cost = layers.cross_entropy(
input=predict, label=label, program=program, init_program=init_program)
avg_cost = layers.mean(x=cost, program=program, init_program=init_program)
sgd_optimizer = optimizer.SGDOptimizer(learning_rate=0.001)
opts = sgd_optimizer.minimize(avg_cost)
BATCH_SIZE = 128
PASS_NUM = 1
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.cifar.train10(), buf_size=128 * 10),
batch_size=BATCH_SIZE)
place = core.CPUPlace()
exe = Executor(place)
exe.run(init_program, feed={}, fetch_list=[])
for pass_id in range(PASS_NUM):
batch_id = 0
for data in train_reader():
img_data = np.array(map(lambda x: x[0].reshape(data_shape),
data)).astype("float32")
y_data = np.array(map(lambda x: x[1], data)).astype("int64")
batch_size = 1
for i in y_data.shape:
batch_size = batch_size * i
y_data = y_data.reshape([batch_size, 1])
tensor_img = core.LoDTensor()
tensor_y = core.LoDTensor()
tensor_img.set(img_data, place)
tensor_y.set(y_data, place)
outs = exe.run(program,
feed={"pixel": tensor_img,
"label": tensor_y},
fetch_list=[avg_cost])
loss = np.array(outs[0])
# print("pass_id:" + str(pass_id) + " batch_id:" + str(batch_id) +
# " loss:" + str(loss))
batch_id = batch_id + 1
if batch_id > 1:
# this model is slow, so if we can train two mini batch, we think it works properly.
exit(0)
exit(1)
import paddle.v2 as paddle
import paddle.v2.framework.layers as layers
import paddle.v2.framework.core as core
import paddle.v2.framework.optimizer as optimizer
from paddle.v2.framework.framework import Program, g_program
from paddle.v2.framework.io import save_inference_model, load_inference_model
import paddle.v2.framework.executor as executor
import unittest
import numpy as np
class TestBook(unittest.TestCase):
def test_fit_line_inference_model(self):
MODEL_DIR = "./tmp/inference_model"
init_program = Program()
program = Program()
x = layers.data(
name='x',
shape=[2],
data_type='float32',
program=program,
init_program=init_program)
y = layers.data(
name='y',
shape=[1],
data_type='float32',
program=program,
init_program=init_program)
y_predict = layers.fc(input=x,
size=1,
act=None,
program=program,
init_program=init_program)
cost = layers.square_error_cost(
input=y_predict,
label=y,
program=program,
init_program=init_program)
avg_cost = layers.mean(
x=cost, program=program, init_program=init_program)
sgd_optimizer = optimizer.SGDOptimizer(learning_rate=0.001)
opts = sgd_optimizer.minimize(avg_cost)
place = core.CPUPlace()
exe = executor.Executor(place)
exe.run(init_program, feed={}, fetch_list=[])
for i in xrange(100):
x_data = np.array(
[[1, 1], [1, 2], [3, 4], [5, 2]]).astype("float32")
y_data = np.array([[-2], [-3], [-7], [-7]]).astype("float32")
tensor_x = core.LoDTensor()
tensor_x.set(x_data, place)
tensor_y = core.LoDTensor()
tensor_y.set(y_data, place)
exe.run(program,
feed={'x': tensor_x,
'y': tensor_y},
fetch_list=[avg_cost])
save_inference_model(MODEL_DIR, ["x", "y"], [avg_cost], exe, program)
outs = exe.run(program,
feed={'x': tensor_x,
'y': tensor_y},
fetch_list=[avg_cost])
expected = np.array(outs[0])
reload(executor) # reload to build a new scope
exe = executor.Executor(place)
[infer_prog, feed_var_names, fetch_vars] = load_inference_model(
MODEL_DIR, exe)
outs = exe.run(
infer_prog,
feed={feed_var_names[0]: tensor_x,
feed_var_names[1]: tensor_y},
fetch_list=fetch_vars)
actual = np.array(outs[0])
self.assertEqual(feed_var_names, ["x", "y"])
self.assertEqual(len(fetch_vars), 1)
self.assertEqual(str(fetch_vars[0]), str(avg_cost))
self.assertEqual(expected, actual)
if __name__ == '__main__':
unittest.main()
...@@ -93,15 +93,15 @@ class TestBook(unittest.TestCase): ...@@ -93,15 +93,15 @@ class TestBook(unittest.TestCase):
dict_size = 10000 dict_size = 10000
embed_size = 32 embed_size = 32
first_word = layers.data( first_word = layers.data(
name='firstw', shape=[1], data_type='int32', program=program) name='firstw', shape=[1], data_type='int64', program=program)
second_word = layers.data( second_word = layers.data(
name='secondw', shape=[1], data_type='int32', program=program) name='secondw', shape=[1], data_type='int64', program=program)
third_word = layers.data( third_word = layers.data(
name='thirdw', shape=[1], data_type='int32', program=program) name='thirdw', shape=[1], data_type='int64', program=program)
forth_word = layers.data( forth_word = layers.data(
name='forthw', shape=[1], data_type='int32', program=program) name='forthw', shape=[1], data_type='int64', program=program)
next_word = layers.data( next_word = layers.data(
name='nextw', shape=[1], data_type='int32', program=program) name='nextw', shape=[1], data_type='int64', program=program)
embed_first = layers.embedding( embed_first = layers.embedding(
input=first_word, input=first_word,
......
...@@ -7,7 +7,7 @@ class TestLookupTableOp(OpTest): ...@@ -7,7 +7,7 @@ class TestLookupTableOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "lookup_table" self.op_type = "lookup_table"
table = np.random.random((17, 31)).astype("float32") table = np.random.random((17, 31)).astype("float32")
ids = np.random.randint(0, 17, 4).astype("int32") ids = np.random.randint(0, 17, 4).astype("int64")
ids_expand = np.expand_dims(ids, axis=1) ids_expand = np.expand_dims(ids, axis=1)
self.inputs = {'W': table, 'Ids': ids_expand} self.inputs = {'W': table, 'Ids': ids_expand}
self.outputs = {'Out': table[ids]} self.outputs = {'Out': table[ids]}
......
...@@ -35,4 +35,6 @@ class LstmUnitTest(OpTest): ...@@ -35,4 +35,6 @@ class LstmUnitTest(OpTest):
if __name__ == "__main__": if __name__ == "__main__":
# FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5185
exit(0)
unittest.main() unittest.main()
import paddle.v2.framework.core as core
from paddle.v2.framework.op import Operator
import numpy
import paddle.v2 as paddle
exit(
0
) # FIXME(yuyang18): InferShape has been removed, this unittest should be changed until compile time is ready
BATCH_SIZE = 100
scope = core.Scope()
place = core.CPUPlace()
# if you want to test GPU training, you can use gpu place
# place = core.GPUPlace(0)
dev_ctx = core.DeviceContext.create(place)
init_net = core.Net.create()
forward_net = core.Net.create()
backward_net = None
optimize_net = core.Net.create()
def atomic_id():
id = 0
while True:
yield id
id += 1
uniq_id = atomic_id().next
def data_layer(name, dims):
var = scope.var(name)
tensor = var.get_tensor()
tensor.set_dims(dims) # 1 is batch size holder.
return name
def feed_data(name, data):
assert isinstance(data, numpy.ndarray)
tensor = scope.find_var(name).get_tensor()
tensor.set_dims(data.shape)
if data.dtype == numpy.dtype("int32"):
tensor.alloc_int(place)
elif data.dtype == numpy.dtype("float32"):
tensor.alloc_float(place)
else:
raise ValueError("data type not supported")
tensor.set(data, place)
def grad_var_name(var_name):
return var_name + "@GRAD"
def sgd_optimizer(net, param_name, learning_rate=0.005):
grad_name = grad_var_name(param_name)
optimize_op = Operator(
"sgd",
param=param_name,
grad=grad_name,
param_out=param_name,
learning_rate=learning_rate)
net.append_op(optimize_op)
# should use operator and add these to the init_network
def init_param(net, param_name, dims):
scope.var(param_name)
op = Operator(
"uniform_random", Out=param_name, dims=dims, min=-0.5, max=0.5, seed=10)
op.infer_shape(scope)
net.append_op(op)
# fc_layer
def fc_layer(net, input, size, act="softmax", bias=True, param=None, name=None):
"""
The fully connected layer.
:param input: The name of input variable.
:type input: str
:param size: The size of fully connected layer.
:param act: The name of activation.
:param param: The attribute of learnable parameter which can be used to
modify initialization mean and std of the parameter.
:param bias: The attribute of bias. If set False, this layer does not have
a bias.
:param name: The name of this layer. If it is not set explictly, a name
will be generated automatically.
:return: The name of the output variable.
"""
if name is None:
name = "fc_%d" % uniq_id()
if not isinstance(name, str):
raise ValueError("The name of a layer should be a string.")
input_dims = scope.find_var(input).get_tensor().get_dims()
w_name = param or name + ".w"
init_param(net=init_net, param_name=w_name, dims=[input_dims[1], size])
sgd_optimizer(net=optimize_net, param_name=w_name, learning_rate=0.01)
pre_activation = name + ".mul.out"
scope.var(pre_activation)
mul_op = Operator("mul", X=input, Y=w_name, Out=pre_activation)
net.append_op(mul_op)
# create bias variable if needed
if bias:
bias_name = name + ".b"
init_param(net=init_net, param_name=bias_name, dims=[size])
sgd_optimizer(
net=optimize_net, param_name=bias_name, learning_rate=0.001)
bias_out = name + ".rowwise_add.out"
scope.var(bias_out)
rowwise_append_op = Operator(
"rowwise_add", X=pre_activation, b=bias_name, Out=bias_out)
net.append_op(rowwise_append_op)
pre_activation = bias_out
activation_op = Operator(act, X=pre_activation, Y=name)
net.append_op(activation_op)
scope.var(name)
net.infer_shape(scope)
return name
def cross_entropy_layer(net, input, label):
cost_name = "cross_entropy_%d" % uniq_id()
cross_entropy_op = Operator(
"cross_entropy", X=input, Label=label, Y=cost_name)
net.append_op(cross_entropy_op)
scope.var(cost_name)
net.infer_shape(scope)
return cost_name
def create_backward_net(forward_net):
net = core.Operator.backward(forward_net, set())
for input in net.inputs()["all"]:
var = scope.var(input)
var.get_tensor()
for output in net.outputs()["all"]:
var = scope.var(output)
var.get_tensor()
return net
def debug_print_op(op):
print("===============" + op.type() + "==============")
print("***inputs:***")
for input in op.inputs()["all"]:
print input, scope.find_var(input).get_tensor().get_dims()
print("\n***outputs:***")
for output in op.outputs()["all"]:
print output, scope.find_var(output).get_tensor().get_dims()
print("")
print("")
def set_cost(cost):
cost_shape = numpy.array(scope.find_var(cost).get_tensor()).shape
cost_grad = \
scope.find_var(grad_var_name(cost)).get_tensor()
cost_grad.set_dims(cost_shape)
cost_grad.alloc_float(place)
cost_grad.set(numpy.ones(cost_shape).astype("float32"), place)
def get_cost_mean(cost):
cost_data = numpy.array(scope.find_var(cost).get_tensor())
return cost_data.sum() / len(cost_data)
def error_rate(predict, label):
predict_var = numpy.array(scope.find_var(predict).get_tensor()).argmax(
axis=1)
label = numpy.array(scope.find_var(label).get_tensor())
error_num = numpy.sum(predict_var != label)
return error_num / float(len(label))
images = data_layer(name="pixel", dims=[BATCH_SIZE, 784])
labels = data_layer(name="label", dims=[BATCH_SIZE, 1])
fc1 = fc_layer(net=forward_net, input=images, size=100, act="sigmoid")
fc2 = fc_layer(net=forward_net, input=fc1, size=100, act="sigmoid")
predict = fc_layer(net=forward_net, input=fc2, size=10, act="softmax")
cost = cross_entropy_layer(net=forward_net, input=predict, label=labels)
init_net.complete_add_op(True)
forward_net.complete_add_op(True)
backward_net = create_backward_net(forward_net)
optimize_net.complete_add_op(True)
print(init_net)
print(forward_net)
print(backward_net)
print(optimize_net)
debug_print_op(forward_net)
debug_print_op(backward_net)
debug_print_op(optimize_net)
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=8192),
batch_size=BATCH_SIZE)
def test(cost_name):
test_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=BATCH_SIZE)
cost = []
error = []
for data in test_reader():
image_data = numpy.array(map(lambda x: x[0], data)).astype("float32")
label_data = numpy.array(map(lambda x: x[1], data)).astype("int32")
label_data = numpy.expand_dims(label_data, axis=1)
feed_data(images, image_data)
feed_data(labels, label_data)
forward_net.infer_shape(scope)
forward_net.run(scope, dev_ctx)
cost.append(get_cost_mean(cost_name))
error.append(error_rate(predict, "label"))
print("cost=" + str(sum(cost) / float(len(cost))) + " error_rate=" + str(
sum(error) / float(len(error))))
PASS_NUM = 1
init_net.run(scope, dev_ctx)
for pass_id in range(PASS_NUM):
batch_id = 0
for data in train_reader():
image_data = numpy.array(map(lambda x: x[0], data)).astype("float32")
label_data = numpy.array(map(lambda x: x[1], data)).astype("int32")
label_data = numpy.expand_dims(label_data, axis=1)
feed_data(images, image_data)
feed_data(labels, label_data)
forward_net.infer_shape(scope)
forward_net.run(scope, dev_ctx)
set_cost(cost)
backward_net.infer_shape(scope)
backward_net.run(scope, dev_ctx)
optimize_net.run(scope, dev_ctx)
if batch_id % 100 == 0:
print("pass[" + str(pass_id) + "] batch_id[" + str(batch_id) + "]")
test(cost)
batch_id = batch_id + 1
...@@ -45,4 +45,6 @@ class TestModifiedHuberLossOp(OpTest): ...@@ -45,4 +45,6 @@ class TestModifiedHuberLossOp(OpTest):
if __name__ == '__main__': if __name__ == '__main__':
exit(0)
# FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5184
unittest.main() unittest.main()
此差异已折叠。
...@@ -49,9 +49,12 @@ class TestPool2d_Op(OpTest): ...@@ -49,9 +49,12 @@ class TestPool2d_Op(OpTest):
self.init_test_case() self.init_test_case()
self.init_op_type() self.init_op_type()
self.init_pool_type() self.init_pool_type()
if self.global_pool:
self.paddings = [0 for _ in range(len(self.paddings))]
input = np.random.random(self.shape).astype("float32") input = np.random.random(self.shape).astype("float32")
output = self.pool2D_forward_naive(input, self.ksize, self.strides, output = self.pool2D_forward_naive(input, self.ksize, self.strides,
self.paddings, self.global_pool) self.paddings,
self.global_pool).astype("float32")
self.inputs = {'X': input} self.inputs = {'X': input}
self.attrs = { self.attrs = {
......
...@@ -99,6 +99,8 @@ class TestProgram(unittest.TestCase): ...@@ -99,6 +99,8 @@ class TestProgram(unittest.TestCase):
outputs={"Out": add_out}, outputs={"Out": add_out},
attrs={"x_num_col_dims": 1}) attrs={"x_num_col_dims": 1})
self.assertEqual(mul_op.idx, 0)
self.assertEqual(add_op.idx, 1)
param_to_grad = prog.append_backward(add_out, set()) param_to_grad = prog.append_backward(add_out, set())
def grad_name(name): def grad_name(name):
......
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册