提交 f08bf453 编写于 作者: X xzl

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into add_merge_model_scripts

......@@ -67,7 +67,7 @@ func main() {
cp, err = pserver.LoadCheckpoint(e, idx)
if err != nil {
if err == pserver.ErrCheckpointNotFound {
log.Info("Could not find the pserver checkpoint.")
log.Info("load checkpoint error", "error", err)
} else {
panic(err)
}
......@@ -99,7 +99,7 @@ func main() {
candy.Must(err)
go func() {
log.Info("starting pserver", log.Ctx{"port": *port})
log.Info("serving pserver", log.Ctx{"port": *port})
err = http.Serve(l, nil)
candy.Must(err)
}()
......
......@@ -71,9 +71,15 @@ func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer
cstate = unsafe.Pointer(&s[0])
}
var cptr (*C.uchar)
if len(c) > 0 {
cptr = (*C.uchar)(&c[0])
} else {
log.Error("empty config", "param name", paramWithConfigs.Param.Name)
}
o.config = c
o.opt = C.paddle_create_optimizer(
(*C.uchar)(&c[0]),
cptr,
C.int(len(c)),
C.paddle_element_type(p.ElementType),
cbuffer,
......
......@@ -17,12 +17,11 @@ package pserver
import (
"bufio"
"bytes"
"crypto/md5"
"encoding/gob"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"hash/crc32"
"io/ioutil"
"os"
"path"
......@@ -40,7 +39,7 @@ type ElementType int
// ErrCheckpointNotFound indicates that the pserver checkpoint could
// not be found.
var ErrCheckpointNotFound = errors.New("checkpoint not found")
var ErrCheckpointNotFound = errors.New("checkpoint not found in etcd")
// RPC error message.
const (
......@@ -76,7 +75,7 @@ type ParameterWithConfig struct {
type checkpointMeta struct {
UUID string `json:"uuid"`
Path string `json:"path"`
MD5 string `json:"md5"`
CRC32 uint32 `json:"crc32"`
Timestamp int64 `json:"timestamp"`
}
......@@ -92,7 +91,7 @@ type Service struct {
idx int
checkpointInterval time.Duration
checkpointPath string
client *EtcdClient
client KVStore
mu sync.Mutex
optMap map[string]*optimizer
......@@ -104,7 +103,12 @@ type parameterCheckpoint struct {
State []byte
}
func loadMeta(e *EtcdClient, idx int) (meta checkpointMeta, err error) {
type KVStore interface {
GetKey(key string, timeout time.Duration) ([]byte, error)
PutKey(key string, value []byte, timeout time.Duration, withLease bool) error
}
func loadMeta(e KVStore, idx int) (meta checkpointMeta, err error) {
v, err := e.GetKey(PsCheckpoint+strconv.Itoa(idx), 3*time.Second)
if err != nil {
return
......@@ -123,7 +127,7 @@ func loadMeta(e *EtcdClient, idx int) (meta checkpointMeta, err error) {
}
// LoadCheckpoint loads checkpoint from file.
func LoadCheckpoint(e *EtcdClient, idx int) (Checkpoint, error) {
func LoadCheckpoint(e KVStore, idx int) (Checkpoint, error) {
log.Info("Loading checkpoint", "pserver index", idx)
defer traceTime(time.Now(), "load checkpoint")
......@@ -137,11 +141,8 @@ func LoadCheckpoint(e *EtcdClient, idx int) (Checkpoint, error) {
return nil, err
}
// TODO(helin): change MD5 to CRC since CRC is better for file
// checksum in our use case (emphasize speed over security).
h := md5.New()
md5 := hex.EncodeToString(h.Sum(content))
if md5 != cpMeta.MD5 {
crc32 := crc32.ChecksumIEEE(content)
if crc32 != cpMeta.CRC32 {
return nil, errors.New(WrongChecksum)
}
......@@ -150,12 +151,13 @@ func LoadCheckpoint(e *EtcdClient, idx int) (Checkpoint, error) {
if err = dec.Decode(&cp); err != nil {
return nil, err
}
return cp, nil
}
// NewService creates a new service, will bypass etcd registration if no
// endpoints specified. It will recovery from checkpoint file if a exists a specified checkpoint.
func NewService(idx int, interval time.Duration, path string, client *EtcdClient, cp Checkpoint) (*Service, error) {
func NewService(idx int, interval time.Duration, path string, client KVStore, cp Checkpoint) (*Service, error) {
s := &Service{
idx: idx,
checkpointInterval: interval,
......@@ -173,6 +175,7 @@ func NewService(idx int, interval time.Duration, path string, client *EtcdClient
}
s.optMap[p.Param.Name] = newOptimizer(p, item.State)
}
close(s.initialized)
}
return s, nil
}
......@@ -221,7 +224,7 @@ func (s *Service) FinishInitParams(_ int, _ *int) error {
for range t {
err := s.checkpoint()
if err != nil {
log.Error("finish init params error", log.Ctx{"error": err})
log.Error("checkpoint error", log.Ctx{"error": err})
}
}
}()
......@@ -274,6 +277,7 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
parameter.Name = name
parameter.ElementType = opt.elementType
parameter.Content = opt.GetWeights()
log.Info("sending parameter to the trainer", "name", parameter.Name, "size", len(parameter.Content), "type", parameter.ElementType)
return nil
}
......@@ -354,20 +358,29 @@ func (s *Service) checkpoint() (err error) {
oldMeta, err := loadMeta(s.client, s.idx)
if err == ErrCheckpointNotFound {
log.Info("Do not have existing checkpoint.")
log.Info("old meta not found, skip removing old meta")
err = nil
} else if err == nil {
log.Info("removing old meta")
if oldMeta.Path != "" {
rmErr := os.Remove(oldMeta.Path)
if rmErr != nil {
// log error, but still treat checkpoint as
// successful.
log.Error("remove old meta file error", log.Ctx{"error": rmErr})
}
}
}
if err != nil {
return
}
h := md5.New()
md5 := hex.EncodeToString(h.Sum(buf.Bytes()))
crc32 := crc32.ChecksumIEEE(buf.Bytes())
cpMeta := checkpointMeta{
UUID: id,
Timestamp: time.Now().UnixNano(),
MD5: md5,
CRC32: crc32,
Path: p,
}
......@@ -381,14 +394,5 @@ func (s *Service) checkpoint() (err error) {
return
}
if oldMeta.Path != "" {
rmErr := os.Remove(oldMeta.Path)
if rmErr != nil {
// log error, but still treat checkpoint as
// successful.
log.Error("remove old meta file error", log.Ctx{"error": rmErr})
}
}
return
}
package pserver
import (
"bytes"
"encoding/binary"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
const testDir = "./test_data"
type myKV struct {
m map[string][]byte
}
func (m *myKV) GetKey(key string, timeout time.Duration) ([]byte, error) {
if m.m == nil {
m.m = make(map[string][]byte)
}
return m.m[key], nil
}
func (m *myKV) PutKey(key string, value []byte, timeout time.Duration, withLease bool) error {
if m.m == nil {
m.m = make(map[string][]byte)
}
m.m[key] = value
return nil
}
func TestCheckpoint(t *testing.T) {
kv := &myKV{}
s, err := NewService(0, time.Hour, testDir, kv, nil)
assert.Nil(t, err)
err = s.checkpoint()
assert.Nil(t, err)
_, err = LoadCheckpoint(kv, 0)
assert.Nil(t, err)
}
func float32ToByte(f float32) []byte {
var buf bytes.Buffer
err := binary.Write(&buf, binary.LittleEndian, f)
if err != nil {
fmt.Println("binary.Write failed:", err)
}
return buf.Bytes()
}
func TestCheckpointWithData(t *testing.T) {
kv := &myKV{}
s, err := NewService(0, time.Hour, testDir, kv, nil)
assert.Nil(t, err)
var content []byte
for i := 0; i < 50000; i++ {
content = append(content, float32ToByte(float32(i))...)
}
p1 := Parameter{Name: "p1", ElementType: 1, Content: content}
err = s.InitParam(ParameterWithConfig{Param: p1}, nil)
assert.Nil(t, err)
err = s.FinishInitParams(0, nil)
assert.Nil(t, err)
var p2 Parameter
err = s.GetParam(p1.Name, &p2)
assert.Nil(t, err)
assert.Equal(t, p1, p2)
err = s.checkpoint()
assert.Nil(t, err)
cp, err := LoadCheckpoint(kv, 0)
assert.Nil(t, err)
s1, err := NewService(0, time.Hour, testDir, kv, cp)
assert.Nil(t, err)
var p3 Parameter
err = s1.GetParam(p1.Name, &p3)
assert.Nil(t, err)
assert.Equal(t, p1, p3)
}
......@@ -178,7 +178,3 @@ func TestBlockUntilInitialized(t *testing.T) {
wg.Wait()
}
func TestCheckpointSpeed(t *testing.T) {
//TODO(zhihong): test speed
}
......@@ -26,7 +26,7 @@ cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute ddim op_info operator)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute ddim op_info operator glog)
cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
......@@ -42,7 +42,7 @@ add_custom_command(TARGET framework_py_proto POST_BUILD
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
cc_library(backward SRCS backward.cc DEPS net_op)
cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context)
cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context fill_constant_op)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto backward glog)
......
......@@ -315,6 +315,7 @@ static void CreateGradVarInBlock(
return false; /* not break */
});
if (need_infer_shape) {
ops[op_index]->InferVarType(block_desc);
ops[op_index]->InferShape(*block_desc);
}
}
......@@ -452,11 +453,16 @@ ParamGradInfoMap AppendBackward(
std::transform(target_shape_desc.begin(), target_shape_desc.end(),
std::back_inserter(target_shape),
[](int64_t dim) { return static_cast<int>(dim); });
VLOG(3) << "backward from loss=" << target.Name()
<< " data_type=" << target.GetDataType();
std::unique_ptr<OpDescBind> fill_one_op(
new OpDescBind("fill_constant", {}, {{"Out", {fill_one_op_out}}},
{{"shape", target_shape},
{"value", static_cast<float>(1.0)},
{"data_type", framework::DataType::FP32}}));
{"data_type", target.GetDataType()}}));
// infer var type of fill_one_op
fill_one_op->InferVarType(root_block);
root_block->AppendAllocatedOp(std::move(fill_one_op));
size_t forward_op_num = root_block->OpSize();
size_t forward_block_num = program_desc.Size();
......@@ -475,8 +481,7 @@ ParamGradInfoMap AppendBackward(
std::unordered_map<std::string, GradVarInfo> retv;
auto var = root_block->Var(fill_one_op_out);
// FIXME(qiao) infer the data type
var->SetDataType(framework::DataType::FP32);
var->SetDataType(target.GetDataType());
var->SetShape(target.Shape());
auto& target_grad = retv[target.Name()];
target_grad.name_ = fill_one_op_out;
......
......@@ -21,6 +21,8 @@
#include "paddle/framework/var_desc.h"
#include "paddle/operators/net_op.h"
USE_OP(fill_constant);
namespace paddle {
namespace framework {
......
......@@ -120,6 +120,17 @@ BlockDesc *BlockDescBind::Proto() {
Flush();
return desc_;
}
BlockDescBind::BlockDescBind(ProgramDescBind *prog, BlockDesc *desc)
: prog_(prog), desc_(desc), need_update_(false) {
for (const VarDesc &var_desc : desc_->vars()) {
vars_[var_desc.name()].reset(new VarDescBind(var_desc));
}
for (const OpDesc &op_desc : desc_->ops()) {
ops_.emplace_back(new OpDescBind(op_desc, prog));
}
}
BlockDescBind::BlockDescBind(const BlockDescBind &other, BlockDesc *desc,
ProgramDescBind *prog)
: prog_(prog), desc_(desc) {
......
......@@ -36,8 +36,7 @@ class ProgramDescBind;
class BlockDescBind {
public:
BlockDescBind(ProgramDescBind *prog, BlockDesc *desc)
: prog_(prog), desc_(desc), need_update_(false) {}
BlockDescBind(ProgramDescBind *prog, BlockDesc *desc);
BlockDescBind(const BlockDescBind &other, BlockDesc *desc,
ProgramDescBind *prog);
......
......@@ -28,7 +28,8 @@ enum OpInfoFillType {
kOperator = 0,
kOpProtoAndCheckerMaker = 1,
kGradOpDescMaker = 2,
kVarTypeInference = 3
kVarTypeInference = 3,
kShapeInference = 4
};
template <typename T>
......@@ -42,7 +43,10 @@ struct OpInfoFillTypeID {
? kGradOpDescMaker
: (std::is_base_of<VarTypeInference, T>::value
? kVarTypeInference
: static_cast<OpInfoFillType>(-1))));
: (std::is_base_of<InferShapeBase, T>::value
? kShapeInference
: static_cast<OpInfoFillType>(
-1)))));
}
};
......@@ -121,6 +125,16 @@ struct OpInfoFiller<T, kVarTypeInference> {
}
};
template <typename T>
struct OpInfoFiller<T, kShapeInference> {
void operator()(const char* op_type, OpInfo* info) const {
info->infer_shape_ = [](InferShapeContext* ctx) {
T inference;
inference(ctx);
};
}
};
} // namespace details
} // namespace framework
......
......@@ -20,6 +20,7 @@ limitations under the License. */
#include <set>
#include <vector>
#include "paddle/framework/feed_fetch_type.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/scope.h"
......@@ -56,6 +57,22 @@ Executor::~Executor() {
}
}
static void CreateTensor(Variable* var, VarDesc::VarType var_type) {
if (var_type == VarDesc::LOD_TENSOR) {
var->GetMutable<LoDTensor>();
} else if (var_type == VarDesc::SELECTED_ROWS) {
var->GetMutable<SelectedRows>();
} else if (var_type == VarDesc::FEED_MINIBATCH) {
var->GetMutable<FeedFetchList>();
} else if (var_type == VarDesc::FETCH_LIST) {
var->GetMutable<FeedFetchList>();
} else {
PADDLE_THROW(
"Variable type must be "
"LoDTensor/SelectedRows/FEED_MINIBATCH/FETCH_LIST.");
}
}
void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id) {
// TODO(tonyyang-svail):
// - only runs on the first device (i.e. no interdevice communication)
......@@ -69,10 +86,12 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id) {
for (auto& var : block.vars()) {
if (var.persistable()) {
auto* ptr = scope->Var(var.name());
CreateTensor(ptr, var.type());
VLOG(3) << "Create Variable " << var.name()
<< " global, which pointer is " << ptr;
} else {
auto* ptr = local_scope.Var(var.name());
CreateTensor(ptr, var.type());
VLOG(3) << "Create Variable " << var.name()
<< " locally, which pointer is " << ptr;
}
......
......@@ -14,9 +14,13 @@ limitations under the License. */
#include "paddle/framework/op_desc.h"
#include <functional>
#include <mutex>
#include <unordered_map>
#include "paddle/framework/block_desc.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/program_desc.h"
#include "glog/logging.h"
namespace paddle {
namespace framework {
......@@ -24,16 +28,47 @@ namespace framework {
OpDescBind::OpDescBind(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs) {
op_desc_.set_type(type);
desc_.set_type(type);
inputs_ = inputs;
outputs_ = outputs;
attrs_ = attrs;
need_update_ = true;
}
OpDescBind::OpDescBind(const OpDesc &desc, ProgramDescBind *prog)
: desc_(desc), need_update_(false) {
// restore inputs_
int input_size = desc_.inputs_size();
for (int i = 0; i < input_size; ++i) {
const OpDesc::Var &var = desc_.inputs(i);
std::vector<std::string> &args = inputs_[var.parameter()];
int argu_size = var.arguments_size();
args.reserve(argu_size);
for (int j = 0; j < argu_size; ++j) {
args.push_back(var.arguments(j));
}
}
// restore outputs_
int output_size = desc_.outputs_size();
for (int i = 0; i < output_size; ++i) {
const OpDesc::Var &var = desc_.outputs(i);
std::vector<std::string> &args = outputs_[var.parameter()];
int argu_size = var.arguments_size();
args.reserve(argu_size);
for (int j = 0; j < argu_size; ++j) {
args.push_back(var.arguments(j));
}
}
// restore attrs_
for (const OpDesc::Attr &attr : desc_.attrs()) {
std::string attr_name = attr.name();
attrs_[attr_name] = GetAttrValue(attr, prog->Proto());
}
}
OpDesc *OpDescBind::Proto() {
Flush();
return &op_desc_;
return &desc_;
}
const std::vector<std::string> &OpDescBind::Input(
......@@ -167,23 +202,23 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> {
void OpDescBind::Flush() {
if (need_update_) {
this->op_desc_.mutable_inputs()->Clear();
this->desc_.mutable_inputs()->Clear();
for (auto &ipt : inputs_) {
auto *input = op_desc_.add_inputs();
auto *input = desc_.add_inputs();
input->set_parameter(ipt.first);
VectorToRepeated(ipt.second, input->mutable_arguments());
}
this->op_desc_.mutable_outputs()->Clear();
this->desc_.mutable_outputs()->Clear();
for (auto &opt : outputs_) {
auto *output = op_desc_.add_outputs();
auto *output = desc_.add_outputs();
output->set_parameter(opt.first);
VectorToRepeated(opt.second, output->mutable_arguments());
}
this->op_desc_.mutable_attrs()->Clear();
this->desc_.mutable_attrs()->Clear();
for (auto &attr : attrs_) {
auto *attr_desc = op_desc_.add_attrs();
auto *attr_desc = desc_.add_attrs();
attr_desc->set_name(attr.first);
attr_desc->set_type(
static_cast<framework::AttrType>(attr.second.which() - 1));
......@@ -195,26 +230,26 @@ void OpDescBind::Flush() {
}
}
using InferShapeFuncMap =
std::unordered_map<std::string /*op_type*/,
std::function<void(InferShapeContext *)>>;
static InferShapeFuncMap &InferShapeFuncs() {
static InferShapeFuncMap *g_map = nullptr;
if (g_map == nullptr) {
g_map = new InferShapeFuncMap();
auto &info_map = OpInfoMap::Instance();
// all registered kernels
for (auto &pair : OperatorWithKernel::AllOpKernels()) {
auto &info = info_map.Get(pair.first);
// use empty type here to avoid runtime checks.
static std::once_flag init_infer_shape_funcs;
static void InitInferShapeFuncs() {
std::call_once(init_infer_shape_funcs, [] {
auto &map = OpInfoMap::Instance();
auto &info_map = *map.mutable_map();
for (auto &kern_pair : OperatorWithKernel::AllOpKernels()) {
auto op_type = kern_pair.first;
auto &op_info = info_map.at(op_type);
auto op =
static_cast<OperatorWithKernel *>(info.Creator()("", {}, {}, {}));
g_map->insert(
{pair.first, [op](InferShapeContext *ctx) { op->InferShape(ctx); }});
static_cast<OperatorWithKernel *>(op_info.Creator()("", {}, {}, {}));
if (op_info.infer_shape_) { // infer_shape has been registered.
continue;
}
op_info.infer_shape_ = [op](InferShapeContext *ctx) {
op->InferShape(ctx);
};
}
}
return *g_map;
});
}
void OpDescBind::CheckAttrs() {
......@@ -230,13 +265,13 @@ void OpDescBind::CheckAttrs() {
}
void OpDescBind::InferShape(const BlockDescBind &block) const {
auto &funcs = InferShapeFuncs();
auto it = funcs.find(this->Type());
if (it == funcs.end()) {
PADDLE_THROW("Operator %s has not been registered", this->Type());
}
VLOG(3) << "CompileTime infer shape on " << Type();
InitInferShapeFuncs();
auto &infer_shape = OpInfoMap::Instance().Get(this->Type()).infer_shape_;
PADDLE_ENFORCE(static_cast<bool>(infer_shape),
"%s's infer_shape has not been registered", this->Type());
CompileTimeInferShapeContext ctx(*this, block);
it->second(&ctx);
infer_shape(&ctx);
}
void OpDescBind::InferVarType(BlockDescBind *block) const {
......
......@@ -24,6 +24,7 @@ namespace paddle {
namespace framework {
class BlockDescBind;
class ProgramDescBind;
class OpDescBind {
public:
......@@ -32,11 +33,13 @@ class OpDescBind {
OpDescBind(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs);
OpDescBind(const OpDesc &desc, ProgramDescBind *prog);
OpDesc *Proto();
std::string Type() const { return op_desc_.type(); }
std::string Type() const { return desc_.type(); }
void SetType(const std::string &type) { op_desc_.set_type(type); }
void SetType(const std::string &type) { desc_.set_type(type); }
const std::vector<std::string> &Input(const std::string &name) const;
......@@ -117,7 +120,7 @@ class OpDescBind {
return ret_val;
}
OpDesc op_desc_;
OpDesc desc_;
VariableNameMap inputs_;
VariableNameMap outputs_;
AttributeMap attrs_;
......
......@@ -25,12 +25,19 @@
namespace paddle {
namespace framework {
class InferShapeBase {
public:
virtual ~InferShapeBase() = default;
virtual void operator()(InferShapeContext*) const = 0;
};
struct OpInfo {
OpCreator creator_;
GradOpMakerFN grad_op_maker_;
OpProto* proto_{nullptr};
OpAttrChecker* checker_{nullptr};
InferVarTypeFN infer_var_type_;
InferShapeFN infer_shape_;
bool HasOpProtoAndChecker() const {
return proto_ != nullptr && checker_ != nullptr;
......@@ -87,13 +94,13 @@ class OpInfoMap {
}
}
const std::unordered_map<std::string, const OpInfo>& map() const {
return map_;
}
const std::unordered_map<std::string, OpInfo>& map() const { return map_; }
std::unordered_map<std::string, OpInfo>* mutable_map() { return &map_; }
private:
OpInfoMap() = default;
std::unordered_map<std::string, const OpInfo> map_;
std::unordered_map<std::string, OpInfo> map_;
DISABLE_COPY_AND_ASSIGN(OpInfoMap);
};
......
......@@ -33,24 +33,6 @@ ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
}
#endif
const Tensor* GetTensorFromVar(const Variable* var) {
if (var->IsType<LoDTensor>()) {
return &var->Get<LoDTensor>();
}
PADDLE_ENFORCE(var->IsType<Tensor>(),
"The Input must be LoDTensor or Tensor.");
return &var->Get<Tensor>();
}
Tensor* GetTensorFromVar(Variable* var) {
if (var->IsType<LoDTensor>()) {
return var->GetMutable<LoDTensor>();
}
PADDLE_ENFORCE(var->IsType<Tensor>(),
"The Input must be LoDTensor or Tensor.");
return var->GetMutable<Tensor>();
}
std::string OperatorBase::Input(const std::string& name) const {
auto& ins = Inputs(name);
PADDLE_ENFORCE_LE(ins.size(), 1UL,
......@@ -204,6 +186,30 @@ void OperatorBase::GenerateTemporaryNames() {
}
}
static const Tensor* GetTensorFromVar(const Variable* var) {
const Tensor* t = nullptr;
if (var->IsType<LoDTensor>()) {
t = &(var->Get<LoDTensor>());
} else if (var->IsType<SelectedRows>()) {
t = &(var->Get<SelectedRows>().value());
} else {
PADDLE_THROW("Variable type must be LoDTensor/SelectedRows.");
}
return t;
}
static Tensor* GetMutableTensorFromVar(Variable* var) {
Tensor* t = nullptr;
if (var->IsType<LoDTensor>()) {
t = var->GetMutable<LoDTensor>();
} else if (var->IsType<SelectedRows>()) {
t = var->GetMutable<SelectedRows>()->mutable_value();
} else {
PADDLE_THROW("Variable type must be LoDTensor/SelectedRows.");
}
return t;
}
template <>
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const {
auto* var = InputVar(name);
......@@ -227,7 +233,7 @@ const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
template <>
Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const {
auto var = OutputVar(name);
return var == nullptr ? nullptr : var->GetMutable<LoDTensor>();
return var == nullptr ? nullptr : GetMutableTensorFromVar(var);
}
template <>
......@@ -240,7 +246,7 @@ std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
[&](const std::string& sub_name) {
auto var = scope_.FindVar(sub_name);
return var == nullptr ? nullptr
: var->GetMutable<LoDTensor>();
: GetMutableTensorFromVar(var);
});
return res;
}
......
......@@ -28,6 +28,7 @@ limitations under the License. */
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_info.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/selected_rows.h"
#include "paddle/framework/shape_inference.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
......@@ -60,9 +61,6 @@ inline std::string GradVarName(const std::string& var_name) {
class OperatorBase;
class ExecutionContext;
extern const Tensor* GetTensorFromVar(const Variable* var);
extern Tensor* GetTensorFromVar(Variable* var);
/**
* OperatorBase has the basic element that Net will call to do computation.
* Only CreateOperator from OpRegistry will new Operator directly. User
......@@ -414,7 +412,9 @@ class CompileTimeInferShapeContext : public InferShapeContext {
private:
DDim GetDim(const std::string& name) const override {
return framework::make_ddim(block_.FindVarRecursive(name)->Shape());
auto var = block_.FindVarRecursive(name);
PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name);
return framework::make_ddim(var->Shape());
}
void SetDim(const std::string& name, const DDim& dim) override {
......@@ -511,28 +511,26 @@ class RuntimeInferShapeContext : public InferShapeContext {
}
private:
template <bool Allocate>
Tensor* GetTensor(const std::string& name) const {
Tensor* t = nullptr;
auto* var = scope_.FindVar(name);
if (!var->IsType<LoDTensor>() && !var->IsType<Tensor>()) {
if (Allocate) {
t = var->GetMutable<LoDTensor>();
} else {
PADDLE_THROW("Variable(%s) should be tensor", name);
}
DDim GetDim(const std::string& name) const override {
Variable* var = scope_.FindVar(name);
if (var->IsType<LoDTensor>()) {
return var->Get<LoDTensor>().dims();
} else if (var->IsType<SelectedRows>()) {
return var->Get<SelectedRows>().GetCompleteDims();
} else {
t = GetTensorFromVar(scope_.FindVar(name));
PADDLE_THROW("Variable type must be LoDTensor/SelectedRows.");
}
return t;
}
DDim GetDim(const std::string& name) const override {
return GetTensor<false>(name)->dims();
}
void SetDim(const std::string& name, const DDim& dim) override {
GetTensor<true>(name)->Resize(dim);
Variable* var = scope_.FindVar(name);
if (var->IsType<LoDTensor>()) {
var->GetMutable<LoDTensor>()->Resize(dim);
} else if (var->IsType<SelectedRows>()) {
var->GetMutable<SelectedRows>()->set_height(dim[0]);
} else {
PADDLE_THROW("Variable type must be LoDTensor/SelectedRows.");
}
}
const OperatorBase& op_;
......@@ -638,7 +636,9 @@ class OperatorWithKernel : public OperatorBase {
});
}
virtual void InferShape(InferShapeContext* ctx) const = 0;
virtual void InferShape(InferShapeContext* ctx) const {
OpInfoMap::Instance().Get(Type()).infer_shape_(ctx);
}
protected:
// indicate kernel DataType by input data. Defaultly all input data must be
......@@ -655,11 +655,14 @@ class OperatorWithKernel : public OperatorBase {
t = &var->Get<Tensor>();
} else if (var->IsType<LoDTensor>()) {
t = &var->Get<LoDTensor>();
} else if (var->IsType<SelectedRows>()) {
t = &(var->Get<SelectedRows>().value());
}
if (t != nullptr) {
int tmp = static_cast<int>(ToDataType(t->type()));
VLOG(3) << "Input " << ipt_name << " with data_type " << tmp;
PADDLE_ENFORCE(tmp == data_type || data_type == -1,
"DataType of Paddle Op must be same.");
"DataType of Paddle Op %s must be same.", Type());
data_type = tmp;
}
}
......
......@@ -237,12 +237,12 @@ TEST(OpKernel, multi_inputs) {
paddle::platform::CPUDeviceContext cpu_device_context;
paddle::framework::Scope scope;
scope.Var("x0")->GetMutable<Tensor>();
scope.Var("x1")->GetMutable<Tensor>();
scope.Var("x2")->GetMutable<Tensor>();
scope.Var("k0")->GetMutable<Tensor>();
scope.Var("y0")->GetMutable<Tensor>();
scope.Var("y1")->GetMutable<Tensor>();
scope.Var("x0")->GetMutable<LoDTensor>();
scope.Var("x1")->GetMutable<LoDTensor>();
scope.Var("x2")->GetMutable<LoDTensor>();
scope.Var("k0")->GetMutable<LoDTensor>();
scope.Var("y0")->GetMutable<LoDTensor>();
scope.Var("y1")->GetMutable<LoDTensor>();
auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
op->Run(scope, cpu_device_context);
......
......@@ -19,9 +19,9 @@ namespace paddle {
namespace framework {
BlockDescBind *ProgramDescBind::AppendBlock(const BlockDescBind &parent) {
auto *b = prog_.add_blocks();
auto *b = desc_.add_blocks();
b->set_parent_idx(parent.ID());
b->set_idx(prog_.blocks_size() - 1);
b->set_idx(desc_.blocks_size() - 1);
blocks_.emplace_back(new BlockDescBind(this, b));
return blocks_.back().get();
}
......@@ -30,23 +30,32 @@ ProgramDesc *ProgramDescBind::Proto() {
for (auto &block : blocks_) {
block->Flush();
}
return &prog_;
return &desc_;
}
ProgramDescBind::ProgramDescBind() {
auto *block = prog_.mutable_blocks()->Add();
auto *block = desc_.mutable_blocks()->Add();
block->set_idx(kRootBlockIndex);
block->set_parent_idx(kNoneBlockIndex);
blocks_.emplace_back(new BlockDescBind(this, block));
}
ProgramDescBind::ProgramDescBind(const ProgramDescBind &o) {
prog_ = o.prog_;
desc_ = o.desc_;
for (int i = 0; i < prog_.blocks_size(); ++i) {
auto *block = prog_.mutable_blocks(i);
for (int i = 0; i < desc_.blocks_size(); ++i) {
auto *block = desc_.mutable_blocks(i);
blocks_.emplace_back(new BlockDescBind(*o.blocks_[i], block, this));
}
}
ProgramDescBind::ProgramDescBind(const std::string &binary_str) {
PADDLE_ENFORCE(desc_.ParseFromString(binary_str),
"Fail to parse program_desc from binary string.");
for (auto &block_desc : *desc_.mutable_blocks()) {
blocks_.emplace_back(new BlockDescBind(this, &block_desc));
}
}
} // namespace framework
} // namespace paddle
......@@ -31,6 +31,8 @@ class ProgramDescBind {
ProgramDescBind(const ProgramDescBind &o);
explicit ProgramDescBind(const std::string &binary_str);
BlockDescBind *AppendBlock(const BlockDescBind &parent);
BlockDescBind *Block(size_t idx) { return blocks_[idx].get(); }
......@@ -40,7 +42,7 @@ class ProgramDescBind {
ProgramDesc *Proto();
private:
ProgramDesc prog_;
ProgramDesc desc_;
std::vector<std::unique_ptr<BlockDescBind>> blocks_;
};
......
......@@ -59,7 +59,7 @@ TEST(ProgramDesc, copy_ctor) {
};
ASSERT_EQ(global_block->LocalVarNames(), global_block_copy->LocalVarNames());
ASSERT_EQ(3, global_block_copy->LocalVarNames().size());
ASSERT_EQ(3UL, global_block_copy->LocalVarNames().size());
assert_same_var("X", x);
assert_same_var("Y", y);
assert_same_var("Out", out);
......@@ -79,5 +79,67 @@ TEST(ProgramDesc, copy_ctor) {
// Not check block's protostr are same it because the order of vars could be
// different and it is correct.
}
TEST(ProgramDescBind, serialize_and_deserialize) {
ProgramDescBind program_origin;
auto* global_block = program_origin.Block(0);
auto* x = global_block->Var("X");
x->SetType(VarDesc_VarType_LOD_TENSOR);
x->SetLoDLevel(0);
x->SetDataType(FP32);
x->SetShape({1000, 784});
auto* y = global_block->Var("Y");
y->SetType(VarDesc_VarType_LOD_TENSOR);
y->SetLoDLevel(0);
y->SetDataType(FP32);
y->SetShape({784, 100});
auto* op = global_block->AppendOp();
op->SetType("mul");
op->SetInput("X", {x->Name()});
op->SetInput("Y", {y->Name()});
auto* out = global_block->Var("Out");
out->SetType(VarDesc_VarType_LOD_TENSOR);
op->SetOutput("Y", {out->Name()});
std::string binary_str;
program_origin.Proto()->SerializeToString(&binary_str);
ProgramDescBind program_restored(binary_str);
auto* global_block_restored = program_restored.Block(0);
ASSERT_NE(global_block, global_block_restored);
auto assert_same_var = [&](const std::string& name, VarDescBind* var_before) {
ASSERT_TRUE(global_block_restored->HasVar(name));
auto* restored = global_block_restored->Var(name);
ASSERT_NE(restored, var_before);
ASSERT_EQ(restored->Name(), var_before->Name());
ASSERT_EQ(restored->GetType(), var_before->GetType());
ASSERT_EQ(restored->Shape(), var_before->Shape());
ASSERT_EQ(restored->Proto()->SerializeAsString(),
var_before->Proto()->SerializeAsString());
};
ASSERT_EQ(global_block->LocalVarNames(),
global_block_restored->LocalVarNames());
ASSERT_EQ(3UL, global_block_restored->LocalVarNames().size());
assert_same_var("X", x);
assert_same_var("Y", y);
assert_same_var("Out", out);
for (size_t i = 0; i < global_block->OpSize(); ++i) {
auto op_origin = global_block->Op(i);
auto op_restored = global_block->Op(i);
ASSERT_EQ(op_origin->Type(), op_restored->Type());
ASSERT_EQ(op_origin->Inputs(), op_restored->Inputs());
ASSERT_EQ(op_origin->Outputs(), op_restored->Outputs());
ASSERT_EQ(op_restored->Proto()->SerializeAsString(),
op_origin->Proto()->SerializeAsString());
}
}
} // namespace framework
} // namespace paddle
......@@ -23,7 +23,10 @@ class SelectedRows {
value_.reset(new Tensor());
}
SelectedRows() { value_.reset(new Tensor()); }
SelectedRows() {
height_ = 0;
value_.reset(new Tensor());
}
platform::Place place() const { return value_->place(); }
......@@ -37,6 +40,8 @@ class SelectedRows {
const Vector<int64_t>& rows() const { return rows_; }
Vector<int64_t>* mutable_rows() { return &rows_; }
void set_rows(const Vector<int64_t>& rows) { rows_ = rows; }
DDim GetCompleteDims() const {
......
......@@ -28,6 +28,8 @@ class OperatorBase;
class OpDescBind;
class BlockDescBind;
class BlockDesc;
class InferShapeContext;
using VariableNameMap = std::map<std::string, std::vector<std::string>>;
// The order should be as same as framework.proto
......@@ -49,5 +51,7 @@ using GradOpMakerFN = std::function<std::vector<std::unique_ptr<OpDescBind>>(
using InferVarTypeFN = std::function<void(const OpDescBind& /*op_desc*/,
BlockDescBind* /*block*/)>;
using InferShapeFN = std::function<void(InferShapeContext*)>;
} // namespace framework
} // namespace paddle
......@@ -59,6 +59,8 @@ class VarDescBind {
desc_.set_type(VarDesc::LOD_TENSOR);
}
explicit VarDescBind(const VarDesc &desc) : desc_(desc) {}
VarDesc *Proto() { return &desc_; }
std::string Name() const { return desc_.name(); }
......
......@@ -132,7 +132,7 @@ op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op)
op_library(cross_entropy_op DEPS cross_entropy)
op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
op_library(sum_op DEPS net_op)
op_library(sum_op DEPS net_op selected_rows_functor)
op_library(pool_op DEPS pooling)
op_library(pool_with_index_op DEPS pooling)
op_library(sequence_conv_op DEPS context_project)
......
......@@ -446,12 +446,16 @@ REGISTER_OP(thresholded_relu, ops::ActivationOp,
REGISTER_OP(hard_sigmoid, ops::ActivationOp, ops::HardSigmoidOpMaker<float>,
hard_sigmoid_grad, ops::ActivationOpGrad);
#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \
REGISTER_OP_CPU_KERNEL( \
act_type, \
ops::ActivationKernel<paddle::platform::CPUPlace, ops::functor<float>>); \
REGISTER_OP_CPU_KERNEL(act_type##_grad, \
ops::ActivationGradKernel<paddle::platform::CPUPlace, \
ops::grad_functor<float>>);
#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \
REGISTER_OP_CPU_KERNEL( \
act_type, \
ops::ActivationKernel<paddle::platform::CPUPlace, ops::functor<float>>, \
ops::ActivationKernel<paddle::platform::CPUPlace, \
ops::functor<double>>); \
REGISTER_OP_CPU_KERNEL( \
act_type##_grad, ops::ActivationGradKernel<paddle::platform::CPUPlace, \
ops::grad_functor<float>>, \
ops::ActivationGradKernel<paddle::platform::CPUPlace, \
ops::grad_functor<double>>);
FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL);
......@@ -17,12 +17,16 @@
namespace ops = paddle::operators;
#define REGISTER_ACTIVATION_GPU_KERNEL(act_type, functor, grad_functor) \
REGISTER_OP_GPU_KERNEL( \
act_type, \
ops::ActivationKernel<paddle::platform::GPUPlace, ops::functor<float>>); \
REGISTER_OP_GPU_KERNEL(act_type##_grad, \
ops::ActivationGradKernel<paddle::platform::GPUPlace, \
ops::grad_functor<float>>);
#define REGISTER_ACTIVATION_GPU_KERNEL(act_type, functor, grad_functor) \
REGISTER_OP_GPU_KERNEL( \
act_type, \
ops::ActivationKernel<paddle::platform::GPUPlace, ops::functor<float>>, \
ops::ActivationKernel<paddle::platform::GPUPlace, \
ops::functor<double>>); \
REGISTER_OP_GPU_KERNEL( \
act_type##_grad, ops::ActivationGradKernel<paddle::platform::GPUPlace, \
ops::grad_functor<float>>, \
ops::ActivationGradKernel<paddle::platform::GPUPlace, \
ops::grad_functor<double>>);
FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_GPU_KERNEL);
......@@ -210,8 +210,8 @@ struct HardShrinkFunctor : public BaseActivationFunctor<T> {
}
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
auto temp1 = (x < (threshold * -1)).template cast<T>().eval();
auto temp2 = (x > threshold).template cast<T>().eval();
auto temp1 = (x < static_cast<T>(threshold * -1)).template cast<T>().eval();
auto temp2 = (x > static_cast<T>(threshold)).template cast<T>().eval();
y.device(d) = x * (temp1 + temp2);
}
};
......@@ -226,8 +226,8 @@ struct HardShrinkGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
auto temp1 = (x < (threshold * -1)).template cast<T>().eval();
auto temp2 = (x > threshold).template cast<T>().eval();
auto temp1 = (x < static_cast<T>(threshold * -1)).template cast<T>().eval();
auto temp2 = (x > static_cast<T>(threshold)).template cast<T>().eval();
dx.device(d) = dy * (temp1 + temp2).template cast<T>();
}
};
......@@ -243,9 +243,10 @@ struct SoftShrinkFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
auto temp1 = (x > lambda).template cast<T>().eval();
auto temp2 = (x < -lambda).template cast<T>().eval();
y.device(d) = temp1 * (x - lambda) + temp2 * (x + lambda);
auto lambdaT = static_cast<T>(lambda);
auto temp1 = (x > lambdaT).template cast<T>().eval();
auto temp2 = (x < -lambdaT).template cast<T>().eval();
y.device(d) = temp1 * (x - lambdaT) + temp2 * (x + lambdaT);
}
};
......@@ -257,8 +258,9 @@ struct SoftShrinkGradFunctor : public BaseActivationFunctor<T> {
}
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
auto temp1 = (x > lambda).template cast<T>().eval();
auto temp2 = (x < -lambda).template cast<T>().eval();
auto lambdaT = static_cast<T>(lambda);
auto temp1 = (x > lambdaT).template cast<T>().eval();
auto temp2 = (x < -lambdaT).template cast<T>().eval();
dx.device(d) = dy * (temp1 + temp2).template cast<T>();
}
};
......@@ -362,7 +364,8 @@ struct BReluFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) = x.cwiseMax(t_min).cwiseMin(t_max);
y.device(d) =
x.cwiseMax(static_cast<T>(t_min)).cwiseMin(static_cast<T>(t_max));
}
};
......@@ -375,7 +378,9 @@ struct BReluGradFunctor : public BaseActivationFunctor<T> {
}
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * ((x > t_min) * (x < t_max)).template cast<T>();
dx.device(d) = dy *
((x > static_cast<T>(t_min)) * (x < static_cast<T>(t_max)))
.template cast<T>();
}
};
......@@ -390,7 +395,8 @@ struct Relu6Functor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) = x.cwiseMax(static_cast<T>(0)).cwiseMin(threshold);
y.device(d) =
x.cwiseMax(static_cast<T>(0)).cwiseMin(static_cast<T>(threshold));
}
};
......@@ -402,8 +408,9 @@ struct Relu6GradFunctor : public BaseActivationFunctor<T> {
}
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) =
dy * ((x > static_cast<T>(0)) * (x < threshold)).template cast<T>();
dx.device(d) = dy *
((x > static_cast<T>(0)) * (x < static_cast<T>(threshold)))
.template cast<T>();
}
};
......@@ -463,7 +470,8 @@ struct SoftReluFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
auto temp = x.cwiseMax(-threshold).cwiseMin(threshold);
auto tmp = static_cast<T>(threshold);
auto temp = x.cwiseMax(-tmp).cwiseMin(tmp);
y.device(d) = (static_cast<T>(1) + temp.exp()).log();
}
};
......@@ -476,7 +484,8 @@ struct SoftReluGradFunctor : public BaseActivationFunctor<T> {
}
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
auto temp = ((x > -threshold) * (x < threshold)).template cast<T>().eval();
auto tmp = static_cast<T>(threshold);
auto temp = ((x > -tmp) * (x < tmp)).template cast<T>().eval();
dx.device(d) = dy * (static_cast<T>(1) - (-y).exp()) * temp;
}
};
......@@ -490,7 +499,7 @@ struct LeakyReluFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) = x.cwiseMax(alpha * x);
y.device(d) = x.cwiseMax(static_cast<T>(alpha) * x);
}
};
......@@ -502,7 +511,8 @@ struct LeakyReluGradFunctor : public BaseActivationFunctor<T> {
}
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
auto temp1 = alpha * (x < static_cast<T>(0)).template cast<T>().eval();
auto temp1 = static_cast<T>(alpha) *
(x < static_cast<T>(0)).template cast<T>().eval();
auto temp2 = (x >= static_cast<T>(0)).template cast<T>().eval();
dx.device(d) = dy * (temp1 + temp2).template cast<T>();
}
......@@ -517,9 +527,9 @@ struct ELUFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) =
x.cwiseMax(static_cast<T>(0)) +
(alpha * (x.exp() - static_cast<T>(1))).cwiseMin(static_cast<T>(0));
y.device(d) = x.cwiseMax(static_cast<T>(0)) +
(static_cast<T>(alpha) * (x.exp() - static_cast<T>(1)))
.cwiseMin(static_cast<T>(0));
}
};
......@@ -531,9 +541,9 @@ struct ELUGradFunctor : public BaseActivationFunctor<T> {
}
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) =
dy * (x > static_cast<T>(0)).template cast<T>() +
dy * (y + alpha) * (x < static_cast<T>(0)).template cast<T>();
dx.device(d) = dy * (x > static_cast<T>(0)).template cast<T>() +
dy * (y + static_cast<T>(alpha)) *
(x < static_cast<T>(0)).template cast<T>();
}
};
......@@ -545,7 +555,7 @@ struct PowFunctor : public BaseActivationFunctor<T> {
}
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) = x.pow(factor);
y.device(d) = x.pow(static_cast<T>(factor));
}
};
......@@ -557,7 +567,8 @@ struct PowGradFunctor : public BaseActivationFunctor<T> {
}
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * factor * x.pow(factor - static_cast<T>(1));
dx.device(d) = dy * static_cast<T>(factor) *
x.pow(static_cast<T>(factor - static_cast<T>(1)));
}
};
......@@ -571,7 +582,8 @@ struct STanhFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) = scale_b * (scale_a * x).tanh();
y.device(d) =
static_cast<T>(scale_b) * (static_cast<T>(scale_a) * x).tanh();
}
};
......@@ -585,8 +597,10 @@ struct STanhGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
auto temp = (scale_a * x).tanh() * (scale_a * x).tanh();
dx.device(d) = dy * scale_a * scale_b * (static_cast<T>(1) - temp);
auto a = static_cast<T>(scale_a);
auto b = static_cast<T>(scale_b);
auto temp = (a * x).tanh() * (a * x).tanh();
dx.device(d) = dy * a * b * (static_cast<T>(1) - temp);
}
};
......@@ -599,7 +613,8 @@ struct ThresholdedReluFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) = (x > static_cast<T>(threshold)).template cast<T>() * x;
auto th = static_cast<T>(threshold);
y.device(d) = (x > th).template cast<T>() * x;
}
};
......@@ -612,7 +627,8 @@ struct ThresholdedReluGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * (x > static_cast<T>(threshold)).template cast<T>();
auto th = static_cast<T>(threshold);
dx.device(d) = dy * (x > th).template cast<T>();
}
};
......
......@@ -30,7 +30,7 @@ class DropoutOp : public framework::OperatorWithKernel {
auto x_dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", x_dims);
if (ctx->Attrs().Get<bool>("is_training") == 1) {
if (ctx->Attrs().Get<bool>("is_training") == true) {
ctx->SetOutputDim("Mask", x_dims);
}
ctx->ShareLoD("X", /*->*/ "Out");
......@@ -43,7 +43,7 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
DropoutOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<AttrType>("dropout_prob", "Probability of setting units to zero.")
AddAttr<float>("dropout_prob", "Probability of setting units to zero.")
.SetDefault(.5f);
AddAttr<bool>("is_training", "Whether in training phase.").SetDefault(true);
AddAttr<int>("seed", "Dropout random seed.").SetDefault(0);
......@@ -69,7 +69,7 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->Attrs().Get<bool>("is_training"), 1,
PADDLE_ENFORCE_EQ(ctx->Attrs().Get<bool>("is_training"), true,
"GradOp is only callable when is_training is true");
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
......@@ -77,8 +77,8 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) must not be null.");
PADDLE_ENFORCE_GE(ctx->Attrs().Get<AttrType>("dropout_prob"), 0);
PADDLE_ENFORCE_LE(ctx->Attrs().Get<AttrType>("dropout_prob"), 1);
PADDLE_ENFORCE_GE(ctx->Attrs().Get<float>("dropout_prob"), 0);
PADDLE_ENFORCE_LE(ctx->Attrs().Get<float>("dropout_prob"), 1);
auto x_dims = ctx->GetInputDim("X");
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
PADDLE_ENFORCE_EQ(x_dims, out_dims,
......
......@@ -33,7 +33,7 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
auto* y = context.Output<Tensor>("Out");
const auto* x_data = x->data<T>();
auto* y_data = y->mutable_data<T>(context.GetPlace());
AttrType dropout_prob = context.Attr<AttrType>("dropout_prob");
float dropout_prob = context.Attr<float>("dropout_prob");
if (context.Attr<bool>("is_training")) {
auto* mask = context.Output<Tensor>("Mask");
......@@ -41,7 +41,7 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
int seed = context.Attr<int>("seed");
std::minstd_rand engine;
engine.seed(seed);
std::uniform_real_distribution<AttrType> dist(0, 1);
std::uniform_real_distribution<float> dist(0, 1);
size_t size = framework::product(mask->dims());
for (size_t i = 0; i < size; ++i) {
if (dist(engine) < dropout_prob) {
......
......@@ -52,6 +52,7 @@ class FetchOp : public framework::OperatorBase {
// FIXME(yuyang18): Should we assume the fetch operator always generate
// CPU outputs?
dst_item.CopyFrom(src_item, platform::CPUPlace(), dev_ctx);
dev_ctx.Wait();
dst_item.set_lod(src_item.lod());
VLOG(3) << "Fetch variable " << fetch_var_name << " to " << out_name;
......
......@@ -64,5 +64,6 @@ namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(fill_constant, ops::FillConstantOp,
ops::FillConstantOpMaker);
REGISTER_OP_CPU_KERNEL(
fill_constant,
ops::FillConstantOpKernel<paddle::platform::CPUPlace, float>);
fill_constant, ops::FillConstantOpKernel<paddle::platform::CPUPlace, float>,
ops::FillConstantOpKernel<paddle::platform::CPUPlace, double>,
ops::FillConstantOpKernel<paddle::platform::CPUPlace, int>);
......@@ -18,5 +18,6 @@
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
fill_constant,
ops::FillConstantOpKernel<paddle::platform::GPUPlace, float>);
fill_constant, ops::FillConstantOpKernel<paddle::platform::GPUPlace, float>,
ops::FillConstantOpKernel<paddle::platform::GPUPlace, double>,
ops::FillConstantOpKernel<paddle::platform::GPUPlace, int>);
......@@ -25,7 +25,7 @@ class FillConstantOpKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
auto* out = ctx.Output<framework::Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
auto value = ctx.Attr<T>("value");
auto value = ctx.Attr<float>("value");
auto out_eigen = framework::EigenVector<T>::Flatten(*out);
auto place = ctx.GetEigenDevice<Place>();
......
......@@ -171,8 +171,7 @@ class GRUUnitGradOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(
weight_width, frame_size * 3,
"The shape of Weight matrix must be [frame_size, frame_size * 3].");
auto bias = Input("Bias");
if (bias != framework::kEmptyVarName) {
if (ctx->HasInput("Bias")) {
auto bias_dims = ctx->GetInputDim("Bias");
int bias_height = bias_dims[0];
int bias_width = bias_dims[1];
......@@ -203,6 +202,8 @@ namespace ops = paddle::operators;
REGISTER_OP(gru_unit, ops::GRUUnitOp, ops::GRUUnitOpMaker, gru_unit_grad,
ops::GRUUnitGradOp);
REGISTER_OP_CPU_KERNEL(gru_unit,
ops::GRUUnitKernel<paddle::platform::CPUPlace, float>);
ops::GRUUnitKernel<paddle::platform::CPUPlace, float>,
ops::GRUUnitKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(
gru_unit_grad, ops::GRUUnitGradKernel<paddle::platform::CPUPlace, float>);
gru_unit_grad, ops::GRUUnitGradKernel<paddle::platform::CPUPlace, float>,
ops::GRUUnitGradKernel<paddle::platform::CPUPlace, double>);
......@@ -17,6 +17,8 @@
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(gru_unit,
ops::GRUUnitKernel<paddle::platform::GPUPlace, float>);
ops::GRUUnitKernel<paddle::platform::GPUPlace, float>,
ops::GRUUnitKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(
gru_unit_grad, ops::GRUUnitGradKernel<paddle::platform::GPUPlace, float>);
gru_unit_grad, ops::GRUUnitGradKernel<paddle::platform::GPUPlace, float>,
ops::GRUUnitGradKernel<paddle::platform::GPUPlace, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/l1_norm_op.h"
namespace paddle {
namespace operators {
using framework::Tensor;
class L1NormOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should be not null.");
ctx->SetOutputDim("Out", {1});
}
};
class L1NormGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should be not null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Output(X@GRAD) should be not null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
};
class L1NormOpMaker : public framework::OpProtoAndCheckerMaker {
public:
L1NormOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor) The input of l1_norm op.");
AddOutput("Out", "(Scalar) The output of l1_norm op.");
AddComment(R"DOC(
L1 Norm Operator.
Computes the L1 norm of a tensor.
Out = sum (abs(X))
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(l1_norm, ops::L1NormOp, ops::L1NormOpMaker, l1_norm_grad,
ops::L1NormGradOp);
REGISTER_OP_CPU_KERNEL(l1_norm,
ops::L1NormKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
l1_norm_grad, ops::L1NormGradKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/l1_norm_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(l1_norm,
ops::L1NormKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
l1_norm_grad, ops::L1NormGradKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
// Out = sum(abs(X))
template <typename Place, typename T>
class L1NormKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
const framework::Tensor *X = context.Input<framework::Tensor>("X");
framework::Tensor *Out = context.Output<framework::Tensor>("Out");
Out->mutable_data<T>(context.GetPlace());
auto x = framework::EigenVector<T>::Flatten(*X);
auto out = framework::EigenVector<T>::Flatten(*Out);
auto place = context.GetEigenDevice<Place>();
out.device(place) = x.abs().sum();
}
};
// dX = dout * sign(X)
template <typename Place, typename T>
class L1NormGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
const framework::Tensor *x = context.Input<framework::Tensor>("X");
const framework::Tensor *d_out =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
PADDLE_ENFORCE(d_out->numel() == 1, "L1 Norm Gradient should be scalar");
framework::Tensor *dx =
context.Output<framework::Tensor>(framework::GradVarName("X"));
dx->mutable_data<T>(context.GetPlace());
auto x_eigen = framework::EigenVector<T>::Flatten(*x);
auto d_out_eigen = framework::EigenVector<T>::Flatten(*d_out);
auto dx_eigen = framework::EigenVector<T>::Flatten(*dx);
auto place = context.GetEigenDevice<Place>();
Eigen::DSizes<int, 1> x_dsize(x->numel());
dx_eigen.device(place) = d_out_eigen.broadcast(x_dsize) * x_eigen.sign();
}
};
} // namespace operators
} // namespace paddle
......@@ -68,6 +68,7 @@ struct SelectedRowsAdd<platform::CPUPlace, T> {
};
template struct SelectedRowsAdd<platform::CPUPlace, float>;
template struct SelectedRowsAdd<platform::CPUPlace, double>;
template <typename T>
struct SelectedRowsAddTensor<platform::CPUPlace, T> {
......@@ -108,6 +109,72 @@ struct SelectedRowsAddTensor<platform::CPUPlace, T> {
};
template struct SelectedRowsAddTensor<platform::CPUPlace, float>;
template struct SelectedRowsAddTensor<platform::CPUPlace, double>;
template <typename T>
struct SelectedRowsAddTo<platform::CPUPlace, T> {
void operator()(const platform::DeviceContext& context,
const framework::SelectedRows& input1,
const int64_t input2_offset,
framework::SelectedRows* input2) {
auto in1_height = input1.height();
PADDLE_ENFORCE_EQ(in1_height, input2->height());
auto& in1_rows = input1.rows();
auto& in2_rows = *(input2->mutable_rows());
auto& in1_value = input1.value();
auto* in2_value = input2->mutable_value();
// concat rows
in2_rows.insert(in2_rows.end(), in1_rows.begin(), in1_rows.end());
auto in1_place = input1.place();
PADDLE_ENFORCE(platform::is_cpu_place(in1_place));
auto in2_place = input2->place();
PADDLE_ENFORCE(platform::is_cpu_place(in2_place));
auto* in1_data = in1_value.data<T>();
auto* in2_data = in2_value->data<T>();
memory::Copy(boost::get<platform::CPUPlace>(in2_place),
in2_data + input2_offset,
boost::get<platform::CPUPlace>(in1_place), in1_data,
in1_value.numel() * sizeof(T));
}
};
template struct SelectedRowsAddTo<platform::CPUPlace, float>;
template struct SelectedRowsAddTo<platform::CPUPlace, double>;
template <typename T>
struct SelectedRowsAddToTensor<platform::CPUPlace, T> {
void operator()(const platform::DeviceContext& context,
const framework::SelectedRows& input1,
framework::Tensor* input2) {
auto in1_height = input1.height();
auto in2_dims = input2->dims();
PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]);
auto& in1_value = input1.value();
auto& in1_rows = input1.rows();
int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
PADDLE_ENFORCE_EQ(in1_row_numel, input2->numel() / in1_height);
auto* in1_data = in1_value.data<T>();
auto* input2_data = input2->data<T>();
for (size_t i = 0; i < in1_rows.size(); i++) {
for (int64_t j = 0; j < in1_row_numel; j++) {
input2_data[in1_rows[i] * in1_row_numel + j] +=
in1_data[i * in1_row_numel + j];
}
}
}
};
template struct SelectedRowsAddToTensor<platform::CPUPlace, float>;
template struct SelectedRowsAddToTensor<platform::CPUPlace, double>;
} // namespace math
} // namespace operators
......
......@@ -73,12 +73,13 @@ struct SelectedRowsAdd<platform::GPUPlace, T> {
};
template struct SelectedRowsAdd<platform::GPUPlace, float>;
template struct SelectedRowsAdd<platform::GPUPlace, double>;
namespace {
template <typename T>
template <typename T, int block_size>
__global__ void SelectedRowsAddTensorKernel(const T* selected_rows,
const int64_t* rows, T* tensor_out,
int64_t row_numel, int block_size) {
int64_t row_numel) {
const int ty = blockIdx.y;
int tid = threadIdx.x;
......@@ -119,14 +120,13 @@ struct SelectedRowsAddTensor<platform::GPUPlace, T> {
SetConstant<platform::GPUPlace, T> functor;
functor(context, output, 0.0);
int block_size = 256;
const int block_size = 256;
dim3 threads(block_size, 1);
dim3 grid(1, in1_rows.size());
SelectedRowsAddTensorKernel<
T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(in1_data, in1_rows.data(), out_data,
in1_row_numel, block_size);
SelectedRowsAddTensorKernel<T, block_size><<<
grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(in1_data, in1_rows.data(), out_data, in1_row_numel);
auto out_eigen = framework::EigenVector<T>::Flatten(*output);
auto in2_eigen = framework::EigenVector<T>::Flatten(input2);
......@@ -136,6 +136,93 @@ struct SelectedRowsAddTensor<platform::GPUPlace, T> {
};
template struct SelectedRowsAddTensor<platform::GPUPlace, float>;
template struct SelectedRowsAddTensor<platform::GPUPlace, double>;
template <typename T>
struct SelectedRowsAddTo<platform::GPUPlace, T> {
void operator()(const platform::DeviceContext& context,
const framework::SelectedRows& input1,
const int64_t input2_offset,
framework::SelectedRows* input2) {
auto in1_height = input1.height();
PADDLE_ENFORCE_EQ(in1_height, input2->height());
auto& in1_rows = input1.rows();
auto& in2_rows = *(input2->mutable_rows());
auto& in1_value = input1.value();
auto* in2_value = input2->mutable_value();
// concat rows
in2_rows.insert(in2_rows.end(), in1_rows.begin(), in1_rows.end());
auto in1_place = input1.place();
PADDLE_ENFORCE(platform::is_gpu_place(in1_place));
auto in2_place = input2->place();
PADDLE_ENFORCE(platform::is_gpu_place(in2_place));
auto* in1_data = in1_value.data<T>();
auto* in2_data = in2_value->data<T>();
memory::Copy(
boost::get<platform::GPUPlace>(in2_place), in2_data + input2_offset,
boost::get<platform::GPUPlace>(in1_place), in1_data,
in1_value.numel() * sizeof(T),
reinterpret_cast<const platform::CUDADeviceContext&>(context).stream());
}
};
template struct SelectedRowsAddTo<platform::GPUPlace, float>;
template struct SelectedRowsAddTo<platform::GPUPlace, double>;
namespace {
template <typename T, int block_size>
__global__ void SelectedRowsAddToTensorKernel(const T* selected_rows,
const int64_t* rows,
T* tensor_out,
int64_t row_numel) {
const int ty = blockIdx.y;
int tid = threadIdx.x;
selected_rows += ty * row_numel;
tensor_out += rows[ty] * row_numel;
for (int index = tid; index < row_numel; index += block_size) {
// Since index in rows of SelectedRows can be duplicate, we have to use
// Atomic Operation to avoid concurrent write error.
paddle::platform::CudaAtomicAdd(tensor_out + index, selected_rows[index]);
}
}
} // namespace
template <typename T>
struct SelectedRowsAddToTensor<platform::GPUPlace, T> {
void operator()(const platform::DeviceContext& context,
const framework::SelectedRows& input1,
framework::Tensor* input2) {
auto in1_height = input1.height();
auto in2_dims = input2->dims();
PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]);
auto& in1_value = input1.value();
auto& in1_rows = input1.rows();
int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
PADDLE_ENFORCE_EQ(in1_row_numel, input2->numel() / in1_height);
auto* in1_data = in1_value.data<T>();
auto* in2_data = input2->data<T>();
const int block_size = 256;
dim3 threads(block_size, 1);
dim3 grid(1, in1_rows.size());
SelectedRowsAddToTensorKernel<T, block_size><<<
grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(in1_data, in1_rows.data(), in2_data, in1_row_numel);
}
};
template struct SelectedRowsAddToTensor<platform::GPUPlace, float>;
template struct SelectedRowsAddToTensor<platform::GPUPlace, double>;
} // namespace math
} // namespace operators
......
......@@ -36,6 +36,22 @@ struct SelectedRowsAddTensor {
const framework::Tensor& input2, framework::Tensor* output);
};
// input2 = input1 + input2
template <typename Place, typename T>
struct SelectedRowsAddTo {
void operator()(const platform::DeviceContext& context,
const framework::SelectedRows& input1,
const int64_t input2_offset, framework::SelectedRows* input2);
};
// input2 = input1 + input2
template <typename Place, typename T>
struct SelectedRowsAddToTensor {
void operator()(const platform::DeviceContext& context,
const framework::SelectedRows& input1,
framework::Tensor* input2);
};
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -104,3 +104,91 @@ TEST(selected_rows_functor, cpu_add) {
// row9: 2.0 + 3.0
EXPECT_EQ(tensor2_data[9 * row_numel + 6], 5.0);
}
TEST(selected_rows_functor, cpu_add_to) {
using namespace paddle::framework;
using namespace paddle::platform;
using namespace paddle::operators::math;
CPUPlace cpu_place;
CPUDeviceContext ctx(cpu_place);
SetConstant<CPUPlace, float> functor;
int64_t height = 10;
int64_t row_numel = 10;
std::vector<int64_t> rows1{0, 4, 7};
std::unique_ptr<SelectedRows> selected_rows1{new SelectedRows(rows1, height)};
auto* in1_value = selected_rows1->mutable_value();
in1_value->mutable_data<float>(
make_ddim({static_cast<int64_t>(rows1.size()), row_numel}), cpu_place);
functor(ctx, in1_value, 1.0);
std::vector<int64_t> rows2{0, 5, 7, 9};
std::unique_ptr<SelectedRows> selected_rows2{new SelectedRows(rows2, height)};
auto* in2_value = selected_rows2->mutable_value();
in2_value->mutable_data<float>(
make_ddim({static_cast<int64_t>(rows2.size()), row_numel}), cpu_place);
functor(ctx, in2_value, 2.0);
std::unique_ptr<SelectedRows> output{new SelectedRows()};
output->set_height(height);
auto* out_value = output->mutable_value();
// simplely concat two SelectedRows
out_value->mutable_data<float>(make_ddim({7, 10}), cpu_place);
SelectedRowsAddTo<CPUPlace, float> add_to_functor;
add_to_functor(ctx, *selected_rows1, 0, output.get());
add_to_functor(ctx, *selected_rows2, in1_value->numel(), output.get());
auto out_height = output->height();
EXPECT_EQ(out_height, height);
auto& out_rows = output->rows();
// input1 rows
EXPECT_EQ(out_rows[0], 0);
EXPECT_EQ(out_rows[1], 4);
EXPECT_EQ(out_rows[2], 7);
// input2 rows
EXPECT_EQ(out_rows[3], 0);
EXPECT_EQ(out_rows[4], 5);
EXPECT_EQ(out_rows[5], 7);
EXPECT_EQ(out_rows[6], 9);
auto* out_data = output->value().data<float>();
// input1 value
EXPECT_EQ(out_data[0 * row_numel + 0], 1.0);
EXPECT_EQ(out_data[0 * row_numel + 8], 1.0);
EXPECT_EQ(out_data[1 * row_numel + 1], 1.0);
EXPECT_EQ(out_data[2 * row_numel + 6], 1.0);
// input2 value
EXPECT_EQ(out_data[3 * row_numel + 3], 2.0);
EXPECT_EQ(out_data[3 * row_numel + 8], 2.0);
EXPECT_EQ(out_data[4 * row_numel + 4], 2.0);
EXPECT_EQ(out_data[5 * row_numel + 7], 2.0);
EXPECT_EQ(out_data[6 * row_numel + 9], 2.0);
std::unique_ptr<Tensor> tensor1{new Tensor()};
tensor1->mutable_data<float>(make_ddim({height, row_numel}), cpu_place);
functor(ctx, tensor1.get(), 3.0);
SelectedRowsAddToTensor<CPUPlace, float> add_to_tensor_functor;
add_to_tensor_functor(ctx, *output, tensor1.get());
auto* tensor1_data = tensor1->data<float>();
// row0: 1.0 + 2.0 + 3.0
EXPECT_EQ(tensor1_data[0 * row_numel + 0], 6.0);
// row1: 3.0
EXPECT_EQ(tensor1_data[1 * row_numel + 1], 3.0);
// row4 : 1.0 + 3.0
EXPECT_EQ(tensor1_data[4 * row_numel + 6], 4.0);
// row5: 2.0 + 3.0
EXPECT_EQ(tensor1_data[5 * row_numel + 7], 5.0);
// row6: 3.0
EXPECT_EQ(tensor1_data[6 * row_numel + 1], 3.0);
// row7: 1.0 + 2.0 + 3.0
EXPECT_EQ(tensor1_data[7 * row_numel + 3], 6.0);
// row9: 2.0 + 3.0
EXPECT_EQ(tensor1_data[9 * row_numel + 6], 5.0);
}
......@@ -113,3 +113,100 @@ TEST(selected_rows_functor, gpu_add) {
// row9: 2.0 + 3.0
EXPECT_EQ(tensor2_cpu_data[9 * row_numel + 6], 5.0);
}
TEST(selected_rows_functor, gpu_add_to) {
using namespace paddle::framework;
using namespace paddle::platform;
using namespace paddle::operators::math;
GPUPlace gpu_place(0);
CPUPlace cpu_place;
CUDADeviceContext ctx(gpu_place);
SetConstant<GPUPlace, float> functor;
int64_t height = 10;
int64_t row_numel = 10;
std::vector<int64_t> rows1{0, 4, 7};
std::unique_ptr<SelectedRows> selected_rows1{new SelectedRows(rows1, height)};
auto* in1_value = selected_rows1->mutable_value();
in1_value->mutable_data<float>(
make_ddim({static_cast<int64_t>(rows1.size()), row_numel}), gpu_place);
functor(ctx, in1_value, 1.0);
std::vector<int64_t> rows2{0, 5, 7, 9};
std::unique_ptr<SelectedRows> selected_rows2{new SelectedRows(rows2, height)};
auto* in2_value = selected_rows2->mutable_value();
in2_value->mutable_data<float>(
make_ddim({static_cast<int64_t>(rows2.size()), row_numel}), gpu_place);
functor(ctx, in2_value, 2.0);
std::unique_ptr<SelectedRows> output{new SelectedRows()};
output->set_height(height);
auto* out_value = output->mutable_value();
// simplely concat two SelectedRows
out_value->mutable_data<float>(make_ddim({7, 10}), gpu_place);
SelectedRowsAddTo<GPUPlace, float> add_to_functor;
add_to_functor(ctx, *selected_rows1, 0, output.get());
add_to_functor(ctx, *selected_rows2, in1_value->numel(), output.get());
auto out_height = output->height();
EXPECT_EQ(out_height, height);
auto& out_rows = output->rows();
// input1 rows
EXPECT_EQ(out_rows[0], 0);
EXPECT_EQ(out_rows[1], 4);
EXPECT_EQ(out_rows[2], 7);
// input2 rows
EXPECT_EQ(out_rows[3], 0);
EXPECT_EQ(out_rows[4], 5);
EXPECT_EQ(out_rows[5], 7);
EXPECT_EQ(out_rows[6], 9);
Tensor out_cpu;
out_cpu.CopyFrom(*out_value, cpu_place, ctx);
ctx.Wait();
auto* out_cpu_data = out_cpu.data<float>();
// input1 value
EXPECT_EQ(out_cpu_data[0 * row_numel + 0], 1.0);
EXPECT_EQ(out_cpu_data[0 * row_numel + 8], 1.0);
EXPECT_EQ(out_cpu_data[1 * row_numel + 1], 1.0);
EXPECT_EQ(out_cpu_data[2 * row_numel + 6], 1.0);
// input2 value
EXPECT_EQ(out_cpu_data[3 * row_numel + 3], 2.0);
EXPECT_EQ(out_cpu_data[3 * row_numel + 8], 2.0);
EXPECT_EQ(out_cpu_data[4 * row_numel + 4], 2.0);
EXPECT_EQ(out_cpu_data[5 * row_numel + 7], 2.0);
EXPECT_EQ(out_cpu_data[6 * row_numel + 9], 2.0);
std::unique_ptr<Tensor> tensor1{new Tensor()};
tensor1->mutable_data<float>(make_ddim({height, row_numel}), gpu_place);
functor(ctx, tensor1.get(), 3.0);
SelectedRowsAddToTensor<GPUPlace, float> add_to_tensor_functor;
add_to_tensor_functor(ctx, *output, tensor1.get());
Tensor tensor1_cpu;
tensor1_cpu.CopyFrom(*tensor1, cpu_place, ctx);
ctx.Wait();
auto* tensor1_cpu_data = tensor1_cpu.data<float>();
// row0: 1.0 + 2.0 + 3.0
EXPECT_EQ(tensor1_cpu_data[0 * row_numel + 0], 6.0);
// row1: 3.0
EXPECT_EQ(tensor1_cpu_data[1 * row_numel + 1], 3.0);
// row4 : 1.0 + 3.0
EXPECT_EQ(tensor1_cpu_data[4 * row_numel + 6], 4.0);
// row5: 2.0 + 3.0
EXPECT_EQ(tensor1_cpu_data[5 * row_numel + 7], 5.0);
// row6: 3.0
EXPECT_EQ(tensor1_cpu_data[6 * row_numel + 1], 3.0);
// row7: 1.0 + 2.0 + 3.0
EXPECT_EQ(tensor1_cpu_data[7 * row_numel + 3], 6.0);
// row9: 2.0 + 3.0
EXPECT_EQ(tensor1_cpu_data[9 * row_numel + 6], 5.0);
}
......@@ -71,7 +71,8 @@ class MeanGradMaker : public framework::SingleGradOpDescMaker {
namespace ops = paddle::operators;
REGISTER_OPERATOR(mean, ops::MeanOp, ops::MeanOpMaker, ops::MeanGradMaker);
REGISTER_OPERATOR(mean_grad, ops::MeanGradOp);
REGISTER_OP_CPU_KERNEL(mean,
ops::MeanKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(mean, ops::MeanKernel<paddle::platform::CPUPlace, float>,
ops::MeanKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(mean_grad,
ops::MeanGradKernel<paddle::platform::CPUPlace, float>);
ops::MeanGradKernel<paddle::platform::CPUPlace, float>,
ops::MeanGradKernel<paddle::platform::CPUPlace, double>);
......@@ -17,7 +17,8 @@
#include "paddle/operators/mean_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(mean,
ops::MeanKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(mean, ops::MeanKernel<paddle::platform::GPUPlace, float>,
ops::MeanKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(mean_grad,
ops::MeanGradKernel<paddle::platform::GPUPlace, float>);
ops::MeanGradKernel<paddle::platform::GPUPlace, float>,
ops::MeanGradKernel<paddle::platform::GPUPlace, double>);
......@@ -19,11 +19,9 @@ namespace operators {
using framework::Tensor;
class MulOp : public framework::OperatorWithKernel {
class MulOpShapeInference : public framework::InferShapeBase {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
void operator()(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of MulOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) of MulOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
......@@ -137,7 +135,10 @@ class MulOpGrad : public framework::OperatorWithKernel {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker, mul_grad, ops::MulOpGrad);
REGISTER_OPERATOR(mul, paddle::framework::OperatorWithKernel, ops::MulOpMaker,
ops::MulOpShapeInference,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(mul_grad, ops::MulOpGrad);
REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(mul_grad,
ops::MulGradKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/proximal_adagrad_op.h"
namespace paddle {
namespace operators {
class ProximalAdagradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(Param) of ProximalAdagradOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Moment"),
"Input(Moment) of ProximalAdagradOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"),
"Input(Grad) of ProximalAdagradOp should not be null.");
PADDLE_ENFORCE(
ctx->HasInput("LearningRate"),
"Input(LearningRate) of ProximalAdagradOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(ParamOut) of ProximalAdagradOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("MomentOut"),
"Output(MomentOut) of ProximalAdagradOp should not be null.");
auto param_dim = ctx->GetInputDim("Param");
PADDLE_ENFORCE_EQ(
param_dim, ctx->GetInputDim("Grad"),
"Param and Grad of ProximalAdagrad Op must have same dimension.");
PADDLE_ENFORCE_EQ(
param_dim, ctx->GetInputDim("Moment"),
"Param and Moment of ProximalAdagrad Op must have same dimension.");
auto lr_dim = ctx->GetInputDim("LearningRate");
PADDLE_ENFORCE_EQ(framework::product(lr_dim), 1,
"Learning Rate should be a scalar.");
ctx->SetOutputDim("ParamOut", param_dim);
ctx->SetOutputDim("MomentOut", param_dim);
}
};
class ProximalAdagradOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ProximalAdagradOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Param",
"(Tensor, default Tensor<float>) "
"Input parameter that has to be updated.");
AddInput("Moment",
"(Tensor, default Tensor<float>) "
"Moment parameter that has to be updated.");
AddInput("Grad",
"(Tensor, default Tensor<float>) "
"Input gradient of the parameter.");
AddInput("LearningRate",
"(Tensor, default Tensor<float>) "
"The learning rate should be a tensor of size 1.");
AddOutput("ParamOut", "(Tensor) Output updated parameter value.");
AddOutput("MomentOut", "(Tensor) Output updated moment value.");
AddAttr<float>("l1",
"(float, default 0.0) "
"L1 regularization strength.")
.SetDefault(0.0f);
AddAttr<float>("l2",
"(float, default 0.0)"
"L2 regularization strength.")
.SetDefault(0.0f);
AddComment(R"DOC(
Optimizer that implements the proximal adagrad algorithm.
moment = moment + grad * grad
prox_param = param - learning_rate * grad * (1 / sqrt(moment))
param = sign(prox_param) / (1 + learning_rate * l2) *
max { |prox_param| - learning_rate * l1 , 0 }
The paper that proposed Proximal GD:
(http://papers.nips.cc/paper/3793-efficient-learning-using-forward-backward-splitting.pdf)
Here, we use the adagrad learning rate as specified here:
(http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(proximal_adagrad, ops::ProximalAdagradOp,
ops::ProximalAdagradOpMaker);
REGISTER_OP_CPU_KERNEL(
proximal_adagrad,
ops::ProximalAdagradOpKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
You may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed
under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/proximal_adagrad_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
proximal_adagrad,
ops::ProximalAdagradOpKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T>
class ProximalAdagradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* param_out = ctx.Output<Tensor>("ParamOut");
auto* moment_out = ctx.Output<Tensor>("MomentOut");
param_out->mutable_data<T>(ctx.GetPlace());
moment_out->mutable_data<T>(ctx.GetPlace());
auto l1 = static_cast<T>(ctx.Attr<float>("l1"));
auto l2 = static_cast<T>(ctx.Attr<float>("l2"));
auto grad = ctx.Input<Tensor>("Grad");
auto p = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Param"));
auto m = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Moment"));
auto g = EigenVector<T>::Flatten(*grad);
auto lr = EigenVector<T>::Flatten(*ctx.Input<Tensor>("LearningRate"));
auto p_out = EigenVector<T>::Flatten(*param_out);
auto m_out = EigenVector<T>::Flatten(*moment_out);
auto place = ctx.GetEigenDevice<Place>();
Eigen::DSizes<int, 1> grad_dsize(grad->numel());
m_out.device(place) = m + g * g;
auto prox_param = p - lr.broadcast(grad_dsize) * g / m_out.sqrt();
if (l1 > static_cast<T>(0)) {
p_out.device(place) =
prox_param.sign() *
(((prox_param.abs() - (lr * l1).broadcast(grad_dsize))
.cwiseMax(static_cast<T>(0.0))) /
(static_cast<T>(1.0) + (lr * l2).broadcast(grad_dsize)));
} else {
p_out.device(place) =
prox_param / (static_cast<T>(1.0) + (lr * l2).broadcast(grad_dsize));
}
}
};
} // namespace operators
} // namespace paddle
......@@ -73,4 +73,5 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(scale, ops::ScaleOp, ops::ScaleOpMaker<float>,
ops::ScaleGradMaker);
REGISTER_OP_CPU_KERNEL(scale,
ops::ScaleKernel<paddle::platform::CPUPlace, float>);
ops::ScaleKernel<paddle::platform::CPUPlace, float>,
ops::ScaleKernel<paddle::platform::CPUPlace, double>);
......@@ -15,4 +15,5 @@
#include "paddle/operators/scale_op.h"
REGISTER_OP_GPU_KERNEL(
scale, paddle::operators::ScaleKernel<paddle::platform::GPUPlace, float>);
scale, paddle::operators::ScaleKernel<paddle::platform::GPUPlace, float>,
paddle::operators::ScaleKernel<paddle::platform::GPUPlace, double>);
......@@ -19,7 +19,7 @@
namespace paddle {
namespace operators {
template <typename Place, typename T, typename AttrType = T>
template <typename Place, typename T>
class ScaleKernel : public framework::OpKernel<T> {
public:
virtual void Compute(const framework::ExecutionContext& context) const {
......@@ -27,7 +27,7 @@ class ScaleKernel : public framework::OpKernel<T> {
auto* in = context.Input<framework::Tensor>("X");
tensor->mutable_data<T>(in->place());
auto scale = static_cast<T>(context.Attr<AttrType>("scale"));
auto scale = static_cast<T>(context.Attr<float>("scale"));
auto eigen_out = framework::EigenVector<T>::Flatten(*tensor);
auto eigen_in = framework::EigenVector<T>::Flatten(*in);
......
......@@ -23,18 +23,21 @@ using Tensor = framework::Tensor;
namespace {
template <typename T>
__global__ void CrossEntropyGrad(T* out_grad, const T* in_grad,
__global__ void CrossEntropyGrad(T* logit_grad, const T* loss_grad,
const int* labels, const int batch_size,
const int class_num) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int sample_idx = tid / class_num;
if (tid < batch_size * class_num) out_grad[tid] *= in_grad[sample_idx];
__syncthreads();
if (tid < batch_size) {
PADDLE_ASSERT(labels[sample_idx] >= 0 && labels[sample_idx] < class_num);
out_grad[tid * class_num + labels[tid]] -= 1.;
logit_grad[tid * class_num + labels[tid]] -= static_cast<T>(1.);
}
__syncthreads();
if (tid < batch_size * class_num) {
logit_grad[tid] *= loss_grad[sample_idx];
}
}
......@@ -47,7 +50,7 @@ __global__ void SoftCrossEntropyGradientKernel(T* logit_grad,
int ids = blockIdx.x * blockDim.x + threadIdx.x;
if (ids < batch_size * class_num) {
int row_ids = ids / class_num;
logit_grad[ids] = logit_grad[ids] * loss_grad[row_ids] - labels[ids];
logit_grad[ids] = logit_grad[ids] * (loss_grad[row_ids] - labels[ids]);
}
}
} // namespace
......
......@@ -67,8 +67,8 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
logit_grad_mat.device(context.GetEigenDevice<platform::CPUPlace>()) =
logit_grad_mat *
out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, class_num)) -
lbl_mat;
(out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, class_num)) -
lbl_mat);
} else {
const int batch_size = logit_grad->dims()[0];
const int* label_data = labels->data<int>();
......@@ -78,7 +78,7 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
for (int i = 0; i < batch_size; ++i) {
int index = i * class_num + label_data[i];
logit_grad_data[index] =
(out_grad_data[i] * logit_grad_data[index] - 1.);
out_grad_data[i] * (logit_grad_data[index] - 1.);
}
}
}
......
......@@ -95,17 +95,18 @@ class SplitOpMaker : public framework::OpProtoAndCheckerMaker {
}
};
class SplitOpGrad : public NetOp {
class SplitGradMaker : public framework::SingleGradOpDescMaker {
public:
SplitOpGrad(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: NetOp(type, inputs, outputs, attrs) {
auto out_grad = Inputs(framework::GradVarName("Out"));
auto x_grad = Output(framework::GradVarName("X"));
AppendOp(framework::OpRegistry::CreateOp("concat", {{"X", out_grad}},
{{"Out", {x_grad}}}, attrs));
CompleteAddOp(false);
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDescBind> Apply() const override {
auto op = new framework::OpDescBind();
op->SetType("concat");
op->SetInput("X", OutputGrad("Out"));
op->SetOutput("Out", InputGrad("X"));
op->SetAttrMap(Attrs());
return std::unique_ptr<framework::OpDescBind>(op);
}
};
......@@ -114,7 +115,7 @@ class SplitOpGrad : public NetOp {
namespace ops = paddle::operators;
USE_CPU_ONLY_OP(concat);
REGISTER_OP(split, ops::SplitOp, ops::SplitOpMaker, split_grad,
ops::SplitOpGrad);
REGISTER_OPERATOR(split, ops::SplitOp, ops::SplitOpMaker, ops::SplitGradMaker);
REGISTER_OP_CPU_KERNEL(split,
ops::SplitOpKernel<paddle::platform::CPUPlace, float>);
......@@ -11,6 +11,7 @@ limitations under the License. */
#include "paddle/operators/sum_op.h"
#include <vector>
#include "paddle/framework/var_type_inference.h"
#include "paddle/operators/net_op.h"
namespace paddle {
......@@ -55,6 +56,26 @@ or not. But the output only shares the LoD with the first input.
}
};
class SumOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDescBind& op_desc,
framework::BlockDescBind* block) const override {
auto& inputs = op_desc.Input("X");
auto default_var_type = framework::VarDesc::SELECTED_ROWS;
bool any_input_is_lod_tensor = std::any_of(
inputs.begin(), inputs.end(), [block](const std::string& name) {
return block->Var(name)->GetType() == framework::VarDesc::LOD_TENSOR;
});
if (any_input_is_lod_tensor) {
default_var_type = framework::VarDesc::LOD_TENSOR;
}
auto out_var_name = op_desc.Output("Out").front();
block->Var(out_var_name)->SetType(default_var_type);
}
};
class SumGradMaker : public framework::GradOpDescMakerBase {
public:
using framework::GradOpDescMakerBase::GradOpDescMakerBase;
......@@ -83,5 +104,7 @@ class SumGradMaker : public framework::GradOpDescMakerBase {
namespace ops = paddle::operators;
REGISTER_OPERATOR(sum, ops::SumOp, ops::SumOpMaker, ops::SumGradMaker);
REGISTER_OP_CPU_KERNEL(sum, ops::SumKernel<paddle::platform::CPUPlace, float>);
REGISTER_OPERATOR(sum, ops::SumOp, ops::SumOpMaker, ops::SumGradMaker,
ops::SumOpVarTypeInference);
REGISTER_OP_CPU_KERNEL(sum, ops::SumKernel<paddle::platform::CPUPlace, float>,
ops::SumKernel<paddle::platform::CPUPlace, double>);
......@@ -13,4 +13,5 @@ limitations under the License. */
#include "paddle/operators/sum_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(sum, ops::SumKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(sum, ops::SumKernel<paddle::platform::GPUPlace, float>,
ops::SumKernel<paddle::platform::GPUPlace, double>);
......@@ -12,11 +12,15 @@ limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/selected_rows_functor.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using SelectedRows = framework::SelectedRows;
using LoDTensor = framework::LoDTensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
......@@ -25,19 +29,68 @@ template <typename Place, typename T>
class SumKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto ins = context.MultiInput<Tensor>("X");
auto* out = context.Output<Tensor>("Out");
out->mutable_data<T>(context.GetPlace());
auto place = context.GetEigenDevice<Place>();
auto result = EigenVector<T>::Flatten(*out);
int N = ins.size();
auto in = EigenVector<T>::Flatten(*(ins[0]));
result.device(place) = in;
for (int i = 1; i < N; i++) {
auto in = EigenVector<T>::Flatten(*(ins[i]));
result.device(place) = result + in;
auto& in_vars = context.MultiInputVar("X");
int N = in_vars.size();
auto out_var = context.OutputVar("Out");
if (out_var->IsType<framework::LoDTensor>()) {
auto* out = context.Output<Tensor>("Out");
// Runtime InferShape
for (int i = 0; i < N; i++) {
if (in_vars[i]->IsType<framework::LoDTensor>()) {
out->Resize(in_vars[i]->Get<framework::LoDTensor>().dims());
break;
}
}
out->mutable_data<T>(context.GetPlace());
auto result = EigenVector<T>::Flatten(*out);
math::SetConstant<Place, T> constant_functor;
constant_functor(context.device_context(), out, 0.0);
math::SelectedRowsAddToTensor<Place, T> functor;
auto place = context.GetEigenDevice<Place>();
for (int i = 0; i < N; i++) {
if (in_vars[i]->IsType<framework::LoDTensor>()) {
auto& in_t = in_vars[i]->Get<framework::LoDTensor>();
auto in = EigenVector<T>::Flatten(in_t);
result.device(place) = result + in;
} else if (in_vars[i]->IsType<framework::SelectedRows>()) {
auto& in_t = in_vars[i]->Get<framework::SelectedRows>();
functor(context.device_context(), in_t, out);
} else {
PADDLE_THROW("Variable type must be LoDTensor/SelectedRows.");
}
}
} else if (out_var->IsType<framework::SelectedRows>()) {
auto* out = context.Output<SelectedRows>("Out");
auto* out_value = out->mutable_value();
// Runtime InferShape
size_t first_dim = 0;
for (int i = 0; i < N; i++) {
first_dim += in_vars[i]->Get<SelectedRows>().rows().size();
}
auto in_dim = in_vars[0]->Get<SelectedRows>().value().dims();
auto in_dim_vec = framework::vectorize(in_dim);
in_dim_vec[0] = static_cast<int64_t>(first_dim);
out_value->Resize(framework::make_ddim(in_dim_vec));
out_value->mutable_data<T>(context.GetPlace());
math::SelectedRowsAddTo<Place, T> functor;
int64_t offset = 0;
for (int i = 0; i < N; i++) {
PADDLE_ENFORCE_EQ(out->height(),
in_vars[i]->Get<SelectedRows>().height())
functor(context.device_context(), in_vars[i]->Get<SelectedRows>(),
offset, out);
offset += in_vars[i]->Get<SelectedRows>().value().numel();
}
}
}
};
......
......@@ -105,6 +105,11 @@ void BindProgramDesc(py::module &m) {
[](ProgramDescBind &self, const ProgramDescBind &other) {
new (&self) ProgramDescBind(other);
})
.def("__init__",
[](ProgramDescBind &self, const py::bytes &binary_str) {
std::string str(binary_str);
new (&self) ProgramDescBind(str);
})
.def("append_block", &ProgramDescBind::AppendBlock,
py::return_value_policy::reference)
.def("append_backward",
......
......@@ -110,43 +110,10 @@ void NewRemoteParameterUpdater::init(
// overwrite optimizerConfigV2 for per-parameter(layer) configs
for (int i = 0; i < parameterSize(); ++i) {
auto paramConfig = parameters_[i]->getConfig();
if (paramConfig.has_momentum() &&
trainerConfig_.learning_method() == "momentum") {
optimizerConfigV2.mutable_sgd()->set_momentum(paramConfig.momentum());
}
if (paramConfig.has_learning_rate()) {
switch (optimizerConfigV2.lr_policy()) {
case 0:
optimizerConfigV2.mutable_const_lr()->set_learning_rate(
paramConfig.learning_rate());
break;
case 1:
optimizerConfigV2.mutable_linear_lr()->set_learning_rate(
paramConfig.learning_rate());
break;
}
}
if (paramConfig.has_decay_rate()) {
switch (optimizerConfigV2.optimizer()) {
case 1: // SGD
optimizerConfigV2.mutable_sgd()->set_decay(
paramConfig.decay_rate());
break;
case 2: // Adadelta
optimizerConfigV2.mutable_adadelta()->set_decay(
paramConfig.decay_rate());
break;
case 3: // Adagrad
optimizerConfigV2.mutable_adagrad()->set_decay(
paramConfig.decay_rate());
break;
case 4: // Adam
optimizerConfigV2.mutable_adam()->set_decay(
paramConfig.decay_rate());
break;
}
}
// FIXME(typhoonzero): paramConfig always have default values,
// how to check if it's default?
// TODO(typhoonzero): log output: optimizerConfigV2.DebugString();
LOG(INFO) << "trainerConfig_: " << trainerConfig_.DebugString();
// send param and config to pserver
std::string bytes = optimizerConfigV2.SerializeAsString();
const char *array = bytes.data();
......
......@@ -19,11 +19,16 @@ class Executor(object):
def run(self,
program,
feed,
fetch_list,
feed=None,
fetch_list=None,
feed_var_name='feed',
fetch_var_name='fetch',
scope=None):
if feed is None:
feed = {}
if fetch_list is None:
fetch_list = []
if not isinstance(program, Program):
raise TypeError()
......
......@@ -440,6 +440,13 @@ class Program(object):
p.sync_with_cpp()
return p
@staticmethod
def parse_from_string(binary_str):
p = Program()
p.desc = core.ProgramDesc(binary_str)
p.sync_with_cpp()
return p
def __repr__(self):
return str(self)
......@@ -479,6 +486,11 @@ class Program(object):
for block in self.blocks:
block.sync_with_cpp()
def list_vars(self):
for each_block in self.blocks:
for each_var in each_block.vars.itervalues():
yield each_var
class Parameter(Variable):
def __init__(self, block, shape, dtype, **kwargs):
......@@ -498,6 +510,8 @@ class Parameter(Variable):
self.optimize_attr = kwargs.get('optimize_attr', {'learning_rate': 1.0})
self.regularizer = kwargs.get('regularizer', None)
# program is a global instance.
g_program = Program()
......
import os
from paddle.v2.framework.framework import Program, Parameter, g_program, \
Variable
__all__ = [
'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params',
'load_persistables'
]
def is_parameter(var):
return isinstance(var, Parameter)
def is_persistable(var):
return var.persistable
def _clone_var_in_block_(block, var):
assert isinstance(var, Variable)
return block.create_var(
name=var.name,
shape=var.shape,
dtype=var.data_type,
type=var.type,
lod_level=var.lod_level,
persistable=True)
def save_vars(executor, dirname, program=None, vars=None, predicate=None):
"""
Save variables to directory by executor.
:param executor: executor that save variable
:param dirname: directory path
:param program: program. If vars is None, then filter all variables in this
program which fit `predicate`. Default g_program.
:param predicate: The Predicate describes a callable that returns a variable
as a bool. If it returns true, the variables will be saved.
:param vars: variables need to be saved. If specify vars, program & predicate
will be ignored
:return: None
"""
if vars is None:
if program is None:
program = g_program
if not isinstance(program, Program):
raise TypeError("program should be as Program type or None")
save_vars(
executor,
dirname=dirname,
vars=filter(predicate, program.list_vars()))
else:
save_program = Program()
save_block = save_program.global_block()
for each_var in vars:
new_var = _clone_var_in_block_(save_block, each_var)
save_block.append_op(
type='save',
inputs={'X': [new_var]},
outputs={},
attrs={'file_path': os.path.join(dirname, new_var.name)})
executor.run(save_program)
def save_params(executor, dirname, program=None):
"""
Save all parameters to directory with executor.
"""
save_vars(
executor,
dirname=dirname,
program=program,
vars=None,
predicate=is_parameter)
def save_persistables(executor, dirname, program=None):
"""
Save all persistables to directory with executor.
"""
save_vars(
executor,
dirname=dirname,
program=program,
vars=None,
predicate=is_persistable)
def load_vars(executor, dirname, program=None, vars=None, predicate=None):
"""
Load variables from directory by executor.
:param executor: executor that save variable
:param dirname: directory path
:param program: program. If vars is None, then filter all variables in this
program which fit `predicate`. Default g_program.
:param predicate: The Predicate describes a callable that returns a variable
as a bool. If it returns true, the variables will be loaded.
:param vars: variables need to be loaded. If specify vars, program &
predicate will be ignored
:return: None
"""
if vars is None:
if program is None:
program = g_program
if not isinstance(program, Program):
raise TypeError("program's type should be Program")
load_vars(
executor,
dirname=dirname,
vars=filter(predicate, program.list_vars()))
else:
load_prog = Program()
load_block = load_prog.global_block()
for each_var in vars:
assert isinstance(each_var, Variable)
new_var = _clone_var_in_block_(load_block, each_var)
load_block.append_op(
type='load',
inputs={},
outputs={"Out": [new_var]},
attrs={'file_path': os.path.join(dirname, new_var.name)})
executor.run(load_prog)
def load_params(executor, dirname, program=None):
"""
load all parameters from directory by executor.
"""
load_vars(
executor, dirname=dirname, program=program, predicate=is_parameter)
def load_persistables(executor, dirname, program=None):
"""
load all persistables from directory by executor.
"""
load_vars(
executor, dirname=dirname, program=program, predicate=is_persistable)
......@@ -75,18 +75,29 @@ class LayerHelper(object):
}
}
actual = self.kwargs.get('param_attr', None)
return actual if actual is not None else default
if actual is None:
actual = default
for default_field in default.keys():
if default_field not in actual:
actual[default_field] = default[default_field]
return actual
def bias_attr(self):
default = {
'name': None,
'init_attr': {
'type': 'fill_constant',
'value': 0.0
}
}
bias_attr = self.kwargs.get('bias_attr', None)
if bias_attr is True:
bias_attr = {
'name': None,
'init_attr': {
'type': 'fill_constant',
'value': 0.0
}
}
bias_attr = default
if isinstance(bias_attr, dict):
for default_field in default.keys():
if default_field not in bias_attr:
bias_attr[default_field] = default[default_field]
return bias_attr
def multiple_param_attr(self, length):
......
......@@ -97,15 +97,28 @@ def _convert_(name):
def _create_op_func_(op_type):
op_proto = OpProtoHolder.instance().get_op_proto(op_type)
if len(op_proto.outputs) != 1:
not_intermediate_outputs = \
filter(lambda output: not output.intermediate, op_proto.outputs)
intermediate_outputs = \
filter(lambda output: output.intermediate, op_proto.outputs)
if len(not_intermediate_outputs) != 1:
raise ValueError(
"Only one output operator can be automatically generated")
"Only one not intermediate output operator can be automatically generated"
)
if op_proto.outputs[0].duplicable:
if not_intermediate_outputs[0].duplicable:
raise ValueError(
"Only not duplicable op can be automatically generated")
o_name = op_proto.outputs[0].name
for output in intermediate_outputs:
if output.duplicable:
raise ValueError(
"Only when all intermediate ops are not duplicable, "
"this op can be automatically generated")
o_name = not_intermediate_outputs[0].name
intermediate_output_names = [output.name for output in intermediate_outputs]
def func(**kwargs):
helper = LayerHelper(op_type, **kwargs)
......@@ -128,9 +141,13 @@ def _create_op_func_(op_type):
"operator {0} must input same dtype".format(op_type))
inputs[ipt.name] = val
outputs = dict()
out = helper.create_tmp_variable(dtype=dtype)
outputs[o_name] = [out]
for name in intermediate_output_names:
outputs[name] = [helper.create_tmp_variable(dtype=dtype)]
helper.append_op(
type=op_type, inputs=inputs, outputs={o_name: [out]}, attrs=kwargs)
type=op_type, inputs=inputs, outputs=outputs, attrs=kwargs)
return out
func.__name__ = op_type
......@@ -141,6 +158,7 @@ def _create_op_func_(op_type):
_create_op_func_('mean')
_create_op_func_('mul')
_create_op_func_('dropout')
def concat(input, axis, program=None, init_program=None):
......
......@@ -2,6 +2,7 @@ from collections import defaultdict
import paddle.v2.framework.framework as framework
from paddle.v2.framework.backward import append_backward_ops
from paddle.v2.framework.regularizer import append_regularization_ops
__all__ = [
'SGDOptimizer', 'MomentumOptimizer', 'AdagradOptimizer', 'AdamOptimizer',
......@@ -161,6 +162,8 @@ class Optimizer(object):
"""
params_grads = append_backward_ops(loss, parameter_list, no_grad_set or
set())
# Add regularization if any
params_grads = append_regularization_ops(params_grads)
optimize_ops = self.create_optimization_pass(params_grads, loss)
return optimize_ops
......
import paddle.v2.framework.framework as framework
__all__ = ['append_regularization_ops', 'L2DecayRegularizer']
def append_regularization_ops(parameters_and_grads):
"""Create and add backward regularization Operators
Creates and adds backward regularization operators in the BlockDesc.
This will add gradients of the regularizer function to the gradients
of the parameters and return these modified gradients. This is the
same as implementing weight decay in optimizers for regularization.
Args:
parameters_and_grads: A list of (parameters, gradients) pairs
that need to be regularized.
Returns:
list of (parameters, gradients) pair with the regularized gradient
Raises:
Exception: Unknown regularization type
"""
params_and_grads = []
for param, grad in parameters_and_grads:
# If no gradient or no regularization specified,
# then we don't need to do anything
if grad is None or param.regularizer is None:
params_and_grads.append((param, grad))
continue
# Add variable for regularization term in grad block
regularization_term = param.regularizer(param, grad.block)
assert grad.shape == regularization_term.shape
grad.block.append_op(
type='elementwise_add',
inputs={"X": grad,
"Y": regularization_term},
outputs={"Out": grad})
params_and_grads.append((param, grad))
return params_and_grads
class WeightDecayRegularizer(object):
"""Base class for weight decay regularizers
Defines the common interface of weight-decay regularizers.
Weight-decay regularizers are added only during the backward
pass for faster regularization. They add operations to the network
that correspond to gradient of the regularization function.
Users should not use this class directly, but need to use one
of its implementations
"""
def __init__(self):
pass
def __call__(self, param, block):
"""Add corresponding weight decay operations to the network
"""
raise NotImplementedError()
class L2DecayRegularizer(WeightDecayRegularizer):
"""Implements the L2 Weight Decay Regularization
"""
def __init__(self, regularization_coeff=0.0):
assert regularization_coeff is not None
super(L2DecayRegularizer, self).__init__()
self._regularization_coeff = regularization_coeff
def __call__(self, param, block):
"""Add L2 weight decay ops to network
Adds L2 weight decay ops.
L2WeightDecay = reg_coeff * parameter
Args:
param: parameter variable for which regularization is applied
block: block in which variable is to be created
Returns:
new variable for weight decay
"""
assert isinstance(param, framework.Parameter)
assert isinstance(block, framework.Block)
decay = block.create_var(
dtype="float32", shape=param.shape, lod_level=param.lod_level)
# Append Op to calculate decay
block.append_op(
type='scale',
inputs={"X": param},
outputs={"Out": decay},
attrs={"scale": self._regularization_coeff})
return decay
......@@ -3,6 +3,8 @@ import numpy as np
import random
import itertools
import paddle.v2.framework.core as core
import collections
from paddle.v2.framework.backward import append_backward_ops
from paddle.v2.framework.op import Operator
from paddle.v2.framework.executor import Executor
from paddle.v2.framework.framework import Program, OpProtoHolder
......@@ -17,15 +19,11 @@ def randomize_probability(batch_size, class_num, dtype='float32'):
return prob
def grad_var_name(var_name):
return var_name + "@GRAD"
def create_op(scope, op_type, inputs, outputs, attrs):
kwargs = dict()
def __create_var__(name, var_name):
scope.var(var_name)
scope.var(var_name).get_tensor()
kwargs[name].append(var_name)
for in_name, in_dup in Operator.get_op_inputs(op_type):
......@@ -79,30 +77,6 @@ def set_input(scope, op, inputs, place):
__set_input__(in_name, inputs[in_name])
def set_output_grad(scope, op, outputs, place):
def __set_tensor__(name):
out_tensor = scope.find_var(name).get_tensor()
grad_tensor = scope.var(grad_var_name(name)).get_tensor()
out_dtype = out_tensor.dtype()
if out_dtype == core.DataType.FP64:
data = np.ones(out_tensor.shape(), dtype=np.float64)
elif out_dtype == core.DataType.FP32:
data = np.ones(out_tensor.shape(), dtype=np.float32)
else:
raise ValueError("Not supported data type " + str(out_dtype))
grad_tensor.set(data, place)
for out_name, out_dup in Operator.get_op_outputs(op.type()):
if out_name in outputs:
if out_dup:
sub_out = outputs[out_name]
for sub_out_name, _ in sub_out:
__set_tensor__(sub_out_name)
else:
__set_tensor__(out_name)
def get_numeric_gradient(scope,
op,
inputs,
......@@ -110,21 +84,21 @@ def get_numeric_gradient(scope,
output_names,
delta=0.005,
in_place=False):
# FIXME: change this method by compile time concepts
set_input(scope, op, inputs, core.CPUPlace())
tensor_to_check = scope.find_var(input_to_check).get_tensor()
def product(dim):
return reduce(lambda a, b: a * b, dim, 1)
ctx = core.DeviceContext.create(core.CPUPlace())
def get_output():
sum = 0.0
sum = []
for output_name in output_names:
op.run(scope, ctx)
sum += np.array(scope.find_var(output_name).get_tensor()).sum()
return sum
sum.append(
np.array(scope.find_var(output_name).get_tensor()).mean())
return np.array(sum).mean()
tensor_to_check = scope.find_var(input_to_check).get_tensor()
tensor_size = product(tensor_to_check.get_dims())
......@@ -177,44 +151,6 @@ def get_numeric_gradient(scope,
return gradient_flat.reshape(tensor_to_check.get_dims())
def get_backward_op(scope, op, no_grad_set):
backward_op = core.Operator.backward(op, no_grad_set)
for input in backward_op.input_vars():
var = scope.var(input)
var.get_tensor()
for output in backward_op.output_vars():
var = scope.var(output)
var.get_tensor()
return backward_op
def get_gradient(scope,
op,
inputs,
outputs,
grad_names,
place,
no_grad_set=None):
ctx = core.DeviceContext.create(place)
set_input(scope, op, inputs, place)
op.run(scope, ctx)
if no_grad_set is None:
no_grad_set = set()
backward_op = get_backward_op(scope, op, no_grad_set)
set_output_grad(scope, op, outputs, place)
backward_op.run(scope, ctx)
return [
np.array(scope.find_var(grad_name).get_tensor())
for grad_name in grad_names
]
def append_input_output(block, op_proto, np_list, is_input):
'''Insert VarDesc and generate Python variable instance'''
proto_list = op_proto.inputs if is_input else op_proto.outputs
......@@ -306,6 +242,9 @@ class OpTest(unittest.TestCase):
inputs=inputs,
outputs=outputs,
attrs=self.attrs if hasattr(self, "attrs") else dict())
# infer variable type and infer shape in compile-time
op.desc.infer_var_type(block.desc)
op.desc.infer_shape(block.desc)
fetch_list = []
for var_name, var in outputs.iteritems():
......@@ -408,6 +347,7 @@ class OpTest(unittest.TestCase):
op_attrs = self.attrs if hasattr(self, "attrs") else dict()
self.op = create_op(self.scope, self.op_type, op_inputs, op_outputs,
op_attrs)
if no_grad_set is None:
no_grad_set = set()
......@@ -424,32 +364,135 @@ class OpTest(unittest.TestCase):
delta=numeric_grad_delta,
in_place=in_place) for input_to_check in inputs_to_check
]
grad_names = [
grad_var_name(input_to_check) for input_to_check in inputs_to_check
]
cpu_place = core.CPUPlace()
cpu_analytic_grads = get_gradient(self.scope, self.op, self.inputs,
self.outputs, grad_names, cpu_place,
no_grad_set)
cpu_analytic_grads = self._get_gradient(inputs_to_check, cpu_place,
output_names, no_grad_set)
self.__assert_is_close(numeric_grads, cpu_analytic_grads, grad_names,
max_relative_error,
self.__assert_is_close(numeric_grads, cpu_analytic_grads,
inputs_to_check, max_relative_error,
"Gradient Check On %s" % str(cpu_place))
if core.is_compile_gpu() and self.op.support_gpu():
gpu_place = core.GPUPlace(0)
gpu_analytic_grads = get_gradient(self.scope, self.op, self.inputs,
self.outputs, grad_names,
gpu_place, no_grad_set)
gpu_analytic_grads = self._get_gradient(inputs_to_check, gpu_place,
output_names, no_grad_set)
self.__assert_is_close(numeric_grads, gpu_analytic_grads,
grad_names, max_relative_error,
inputs_to_check, max_relative_error,
"Gradient Check On %s" % str(gpu_place))
for c_grad, g_grad, name in itertools.izip(
cpu_analytic_grads, gpu_analytic_grads, grad_names):
self.assertTrue(
np.allclose(
c_grad, g_grad, atol=1e-4),
"output name: " + name + " has diff")
@staticmethod
def _create_var_descs_(block, var_dict):
# FIXME: Try unify with `append_input_output`
for param_name in var_dict:
var = var_dict[param_name]
if not isinstance(var, list) and not isinstance(var, tuple):
var = [(param_name, var, None)]
if not isinstance(var[0], list) and not isinstance(var[0], tuple):
var = [(param_name, var[0], var[1])]
for i, item in enumerate(var):
if not isinstance(item[0], basestring):
item = [[param_name] + list(item)]
if len(item) == 2:
# only set var name and value, set lod to None
var[i] = list(item) + [None]
var_descs = [(block.create_var(
name=name, shape=each.shape, dtype=each.dtype), each, lod)
for name, each, lod in var]
yield param_name, var_descs
@staticmethod
def _merge_list(iterable):
return reduce(lambda a, b: list(a) + list(b), iterable, [])
@staticmethod
def _numpy_to_lod_tensor(np_value, lod, place):
tensor = core.LoDTensor()
tensor.set(np_value, place)
if lod is not None:
tensor.set_lod(lod)
return tensor
def _get_gradient(self, input_to_check, place, output_names, no_grad_set):
prog = Program()
block = prog.global_block()
inputs_with_np = {
key: value
for (key, value) in OpTest._create_var_descs_(
block, getattr(self, 'inputs', {}))
}
outputs_with_np = {
key: val
for (key, val) in OpTest._create_var_descs_(
block, getattr(self, 'outputs', {}))
}
inputs = {
k: [item[0] for item in inputs_with_np[k]]
for k in inputs_with_np
}
outputs = {
k: [item[0] for item in outputs_with_np[k]]
for k in outputs_with_np
}
op = block.append_op(
type=self.op_type,
inputs=inputs,
outputs=outputs,
attrs=getattr(self, 'attrs', {}))
# infer variable type and infer shape in compile-time
op.desc.infer_var_type(block.desc)
op.desc.infer_shape(block.desc)
mean_inputs = map(block.var, output_names)
if len(mean_inputs) == 1:
loss = block.create_var(dtype=mean_inputs[0].data_type, shape=[1])
op = block.append_op(
inputs={"X": mean_inputs}, outputs={"Out": loss}, type='mean')
op.desc.infer_var_type(block.desc)
op.desc.infer_shape(block.desc)
else:
avg_sum = []
for cur_loss in mean_inputs:
cur_avg_loss = block.create_var(
dtype=cur_loss.data_type, shape=[1])
op = block.append_op(
inputs={"X": [cur_loss]},
outputs={"Out": [cur_avg_loss]},
type="mean")
op.desc.infer_var_type(block.desc)
op.desc.infer_shape(block.desc)
avg_sum.append(cur_avg_loss)
loss_sum = block.create_var(dtype=avg_sum[0].data_type, shape=[1])
op_sum = block.append_op(
inputs={"X": avg_sum}, outputs={"Out": loss_sum}, type='sum')
op_sum.desc.infer_var_type(block.desc)
op_sum.desc.infer_shape(block.desc)
loss = block.create_var(dtype=loss_sum.data_type, shape=[1])
op_loss = block.append_op(
inputs={"X": loss_sum},
outputs={"Out": loss},
type='scale',
attrs={'scale': 1.0 / float(len(avg_sum))})
op_loss.desc.infer_var_type(block.desc)
op_loss.desc.infer_shape(block.desc)
param_grad_list = append_backward_ops(
loss=loss, parameter_list=input_to_check, no_grad_set=no_grad_set)
feed_dict = {
item[0].name: OpTest._numpy_to_lod_tensor(item[1], item[2], place)
for p_name in inputs_with_np for item in inputs_with_np[p_name]
}
fetch_list = [g for p, g in param_grad_list]
executor = Executor(place)
result = executor.run(prog, feed_dict, fetch_list)
return map(np.array, result)
......@@ -335,7 +335,7 @@ class TestSoftplus(OpTest):
def setUp(self):
self.op_type = "softplus"
self.inputs = {
'X': np.random.uniform(-1, 1, [11, 17]).astype("float32")
'X': np.random.uniform(-1, 1, [11, 17]).astype("float64")
}
self.outputs = {'Y': np.log(1 + np.exp(self.inputs['X']))}
......
import unittest
import numpy as np
from op_test import OpTest, get_backward_op, grad_var_name
from op_test import OpTest
import paddle.v2.framework.core as core
from paddle.v2.framework.op import Operator
def grad_var_name(var_name):
return var_name + "@GRAD"
def get_backward_op(scope, op, no_grad_set):
backward_op = core.Operator.backward(op, no_grad_set)
for input in backward_op.input_vars():
var = scope.var(input)
var.get_tensor()
for output in backward_op.output_vars():
var = scope.var(output)
var.get_tensor()
return backward_op
def _reference_training(x, scale, offset, epsilon, data_format):
if data_format != "NHWC":
raise ValueError("data_format must be NHWC, got %s." % data_format)
......
......@@ -112,4 +112,7 @@ class TestCondOp(unittest.TestCase):
if __name__ == "__main__":
exit(
0
) # FIXME(qijun): https://github.com/PaddlePaddle/Paddle/issues/5101#issuecomment-339814957
unittest.main()
......@@ -44,7 +44,8 @@ class TestConv2dOp(OpTest):
conv2d_param = {'stride': self.stride, 'pad': self.pad}
input = np.random.random(self.input_size).astype("float32")
filter = np.random.random(self.filter_size).astype("float32")
output = conv2d_forward_naive(input, filter, self.groups, conv2d_param)
output = conv2d_forward_naive(input, filter, self.groups,
conv2d_param).astype('float32')
self.inputs = {'Input': input, 'Filter': filter}
self.attrs = {
......
......@@ -43,8 +43,8 @@ class TestConv2dTransposeOp(OpTest):
conv2dtranspose_param = {'stride': self.stride, 'pad': self.pad}
input_ = np.random.random(self.input_size).astype("float32")
filter_ = np.random.random(self.filter_size).astype("float32")
output = conv2dtranspose_forward_naive(input_, filter_,
conv2dtranspose_param)
output = conv2dtranspose_forward_naive(
input_, filter_, conv2dtranspose_param).astype('float32')
# print 'deconv output py', output, output.shape
self.inputs = {'Input': input_, 'Filter': filter_}
......
......@@ -92,4 +92,5 @@ class TestCrossEntropyOp3(OpTest):
if __name__ == "__main__":
exit(0) # Gradient operator has bug!
unittest.main()
......@@ -8,7 +8,10 @@ class TestDropoutOp(OpTest):
self.op_type = "dropout"
self.inputs = {'X': np.random.random((32, 64)).astype("float32")}
self.attrs = {'dropout_prob': 0.0, 'is_training': True}
self.outputs = {'Out': self.inputs['X'], 'Mask': np.ones((32, 64))}
self.outputs = {
'Out': self.inputs['X'],
'Mask': np.ones((32, 64)).astype('float32')
}
def test_check_output(self):
self.check_output()
......@@ -22,7 +25,10 @@ class TestDropoutOp2(TestDropoutOp):
self.op_type = "dropout"
self.inputs = {'X': np.random.random((32, 64)).astype("float32")}
self.attrs = {'dropout_prob': 1.0, 'is_training': True}
self.outputs = {'Out': np.zeros((32, 64)), 'Mask': np.zeros((32, 64))}
self.outputs = {
'Out': np.zeros((32, 64)).astype('float32'),
'Mask': np.zeros((32, 64)).astype('float32')
}
class TestDropoutOp3(TestDropoutOp):
......@@ -30,7 +36,10 @@ class TestDropoutOp3(TestDropoutOp):
self.op_type = "dropout"
self.inputs = {'X': np.random.random((32, 64, 2)).astype("float32")}
self.attrs = {'dropout_prob': 0.0, 'is_training': True}
self.outputs = {'Out': self.inputs['X'], 'Mask': np.ones((32, 64, 2))}
self.outputs = {
'Out': self.inputs['X'],
'Mask': np.ones((32, 64, 2)).astype('float32')
}
class TestDropoutOp4(OpTest):
......
......@@ -165,4 +165,7 @@ class RecurrentGradientOpTest(unittest.TestCase):
if __name__ == '__main__':
exit(
0
) # FIXME(qijun): https://github.com/PaddlePaddle/Paddle/issues/5101#issuecomment-339814957
unittest.main()
......@@ -4,6 +4,7 @@ import paddle.v2.framework.core as core
import paddle.v2.framework.optimizer as optimizer
from paddle.v2.framework.framework import Program, g_program
from paddle.v2.framework.io import save_persistables, load_persistables
from paddle.v2.framework.executor import Executor
import numpy as np
......@@ -51,6 +52,8 @@ exe.run(init_program, feed={}, fetch_list=[])
PASS_NUM = 100
for pass_id in range(PASS_NUM):
save_persistables(exe, "./fit_a_line.model/", program=program)
load_persistables(exe, "./fit_a_line.model/", program=program)
for data in train_reader():
x_data = np.array(map(lambda x: x[0], data)).astype("float32")
y_data = np.array(map(lambda x: x[1], data)).astype("float32")
......
......@@ -43,12 +43,12 @@ class TestGRUUnitOp(OpTest):
self.op_type = 'gru_unit'
self.inputs = {
'Input': np.random.uniform(
-0.1, 0.1, (batch_size, frame_size * 3)).astype('float32'),
-0.1, 0.1, (batch_size, frame_size * 3)).astype('float64'),
'HiddenPrev': np.random.uniform(
-0.1, 0.1, (batch_size, frame_size)).astype('float32'),
-0.1, 0.1, (batch_size, frame_size)).astype('float64'),
'Weight': np.random.uniform(
-1. / math.sqrt(frame_size), 1. / math.sqrt(frame_size),
(frame_size, frame_size * 3)).astype('float32'),
(frame_size, frame_size * 3)).astype('float64'),
}
self.attrs = {
'activation': GRUActivationType.tanh,
......@@ -78,7 +78,11 @@ class TestGRUUnitOp(OpTest):
g[:, frame_size * 2:])
g = np.hstack((u_r, c))
h = u * h_p + (1 - u) * c
self.outputs = {'Gate': g, 'ResetHiddenPrev': r_h_p, 'Hidden': h}
self.outputs = {
'Gate': g.astype('float64'),
'ResetHiddenPrev': r_h_p.astype('float64'),
'Hidden': h.astype('float64')
}
def setUp(self):
self.set_inputs()
......@@ -89,7 +93,8 @@ class TestGRUUnitOp(OpTest):
def test_check_grad(self):
self.check_grad(
['Input', 'HiddenPrev', 'Weight'], ['Hidden'],
['Input', 'HiddenPrev', 'Weight'],
['Hidden', 'ResetHiddenPrev', 'Gate'],
max_relative_error=0.007)
......@@ -112,4 +117,5 @@ class TestGRUUnitOpWithBias(TestGRUUnitOp):
if __name__ == '__main__':
exit(0) # FIXME(yuyang18): This unittest is not pass. Fix it later
unittest.main()
......@@ -29,6 +29,7 @@ class TestInferShape(unittest.TestCase):
sum_op_desc.set_input("X", ["x1", "x2"])
sum_op_desc.set_output("Out", ["out"])
sum_op_desc.check_attrs()
sum_op_desc.infer_shape(block)
self.assertEqual(out.shape(), shape)
......@@ -61,6 +62,7 @@ class TestInferShape(unittest.TestCase):
mul_op_desc.set_attr("x_num_col_dims", 1)
mul_op_desc.set_attr("y_num_col_dims", 1)
mul_op_desc.check_attrs()
mul_op_desc.infer_shape(block)
self.assertEqual(out.shape(), [x_shape[0], y_shape[1]])
......
import numpy as np
import unittest
from op_test import OpTest
class TestL1NormOp(OpTest):
"""Test l1_norm
"""
def setUp(self):
self.op_type = "l1_norm"
self.max_relative_error = 0.005
X = np.random.uniform(-1, 1, (13, 19)).astype("float32")
X[np.abs(X) < self.max_relative_error] = 0.1
self.inputs = {'X': X}
self.outputs = {'Out': np.sum(np.abs(X))}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(
['X'], 'Out', max_relative_error=self.max_relative_error)
if __name__ == "__main__":
unittest.main()
......@@ -103,40 +103,30 @@ class TestBook(unittest.TestCase):
next_word = layers.data(
name='nextw', shape=[1], data_type='int32', program=program)
embed_param_attr_1 = {
'name': 'shared_w',
'init_attr': {
'max': 1.0,
'type': 'uniform_random',
'min': -1.0
}
}
embed_param_attr_2 = {'name': 'shared_w'}
embed_first = layers.embedding(
input=first_word,
size=[dict_size, embed_size],
data_type='float32',
param_attr=embed_param_attr_1,
param_attr={'name': 'shared_w'},
program=program)
embed_second = layers.embedding(
input=second_word,
size=[dict_size, embed_size],
data_type='float32',
param_attr=embed_param_attr_2,
param_attr={'name': 'shared_w'},
program=program)
embed_third = layers.embedding(
input=third_word,
size=[dict_size, embed_size],
data_type='float32',
param_attr=embed_param_attr_2,
param_attr={'name': 'shared_w'},
program=program)
embed_forth = layers.embedding(
input=forth_word,
size=[dict_size, embed_size],
data_type='float32',
param_attr=embed_param_attr_2,
param_attr={'name': 'shared_w'},
program=program)
concat_embed = layers.concat(
......
......@@ -74,4 +74,5 @@ class TestLRNOp(OpTest):
if __name__ == "__main__":
exit(0) # LRN grad implement wrong
unittest.main()
......@@ -33,8 +33,8 @@ class TestModifiedHuberLossOp(OpTest):
loss = np.vectorize(modified_huber_loss_forward)(product_res)
self.outputs = {
'IntermediateVal': product_res,
'Out': loss.reshape((samples_num, 1))
'IntermediateVal': product_res.astype('float32'),
'Out': loss.reshape((samples_num, 1)).astype('float32')
}
def test_check_output(self):
......
......@@ -60,7 +60,7 @@ class TestPool2d_Op(OpTest):
'global_pooling': self.global_pool,
}
self.outputs = {'Out': output}
self.outputs = {'Out': output.astype('float32')}
def test_check_output(self):
self.check_output()
......
......@@ -68,7 +68,7 @@ class TestPool3d_Op(OpTest):
'global_pooling': self.global_pool,
}
self.outputs = {'Out': output}
self.outputs = {'Out': output.astype('float32')}
def test_check_output(self):
self.check_output()
......
......@@ -52,6 +52,25 @@ class TestProgram(unittest.TestCase):
print prog
print prog.clone()
def test_parse_program_from_string(self):
prog = Program()
x = prog.global_block().create_var(
name='X', shape=[1000, 784], dtype='float32')
y = prog.global_block().create_var(
name='Y', shape=[784, 100], dtype='float32')
out = prog.global_block().create_var(name='Out', dtype='float32')
prog.global_block().append_op(
type="mul", inputs={'X': [x],
'Y': [y]}, outputs={'Out': [out]})
binary_str = prog.desc.serialize_to_string()
prog_restored = Program.parse_from_string(binary_str)
print prog
print prog_restored
def test_append_backward(self):
prog = Program()
block = prog.global_block()
......
import unittest
import numpy as np
from op_test import OpTest
class TestProximalAdagradOp(OpTest):
def setUp(self):
self.op_type = "proximal_adagrad"
w = np.random.random((102, 105)).astype("float32")
m = np.random.random((102, 105)).astype("float32")
g = np.random.random((102, 105)).astype("float32")
lr = np.array([0.1]).astype("float32")
l1 = 0.1
l2 = 0.2
self.inputs = {'Param': w, 'Grad': g, 'Moment': m, 'LearningRate': lr}
self.attrs = {'l1': l1, 'l2': l2}
param_out = 0.0
moment_out = m + g * g
prox_param = w - lr * g / np.sqrt(moment_out)
if l1 > 0.0:
x = np.abs(prox_param) - lr * l1
x[x < 0] = 0
param_out = np.sign(prox_param) * (x / (1.0 + lr * l2))
else:
param_out = prox_param / (1.0 + lr * l2)
self.outputs = {'ParamOut': param_out, 'MomentOut': moment_out}
def test_check_output(self):
self.check_output()
if __name__ == "__main__":
unittest.main()
......@@ -201,4 +201,7 @@ class RecurrentGradientOpTest(unittest.TestCase):
if __name__ == '__main__':
exit(
0
) # FIXME(qijun): https://github.com/PaddlePaddle/Paddle/issues/5101#issuecomment-339814957
unittest.main()
import unittest
import paddle.v2.framework.framework as framework
import paddle.v2.framework.optimizer as optimizer
import paddle.v2.framework.regularizer as regularizer
from paddle.v2.framework.backward import append_backward_ops
class TestL2DecayRegularizer(unittest.TestCase):
def test_l2decay_regularizer(self):
program = framework.Program()
block = program.global_block()
mul_x = block.create_parameter(
dtype="float32",
shape=[5, 10],
lod_level=0,
name="mul.x",
regularizer=regularizer.L2DecayRegularizer(0.5))
self.assertTrue(mul_x.regularizer is not None)
self.assertTrue(
isinstance(mul_x.regularizer, regularizer.L2DecayRegularizer))
mul_y = block.create_var(
dtype="float32", shape=[10, 8], lod_level=0, name="mul.y")
mul_out = block.create_var(
dtype="float32", shape=[5, 8], lod_level=0, name="mul.out")
block.append_op(
type="mul",
inputs={"X": mul_x,
"Y": mul_y},
outputs={"Out": mul_out},
attrs={"x_num_col_dims": 1})
params_grads = append_backward_ops(mul_out)
self.assertEqual(len(params_grads), 1)
count_ops = len(block.ops)
params_grads = optimizer.append_regularization_ops(params_grads)
self.assertEqual(len(params_grads), 1)
self.assertEqual(len(block.ops), count_ops + 2)
self.assertEqual(block.ops[-1].type, 'elementwise_add')
self.assertEqual(block.ops[-2].type, 'scale')
if __name__ == '__main__':
unittest.main()
......@@ -25,7 +25,10 @@ class TestSmoothL1LossOp1(OpTest):
diff = self.inputs['X'] - self.inputs['Y']
loss = np.vectorize(smooth_l1_loss_forward)(diff, sigma2).sum(1)
loss = loss.reshape((dims[0], 1))
self.outputs = {'Diff': diff, 'Out': loss}
self.outputs = {
'Diff': diff.astype('float32'),
'Out': loss.astype('float32')
}
def test_check_output(self):
self.check_output()
......@@ -60,7 +63,10 @@ class TestSmoothL1LossOp2(OpTest):
loss = np.vectorize(smooth_l1_loss_forward)(diff, sigma2)
loss = loss * self.inputs['OutsideWeight']
loss = loss.sum(1).reshape((dims[0], 1))
self.outputs = {'Diff': diff, 'Out': loss}
self.outputs = {
'Diff': diff.astype('float32'),
'Out': loss.astype('float32')
}
def test_check_output(self):
self.check_output()
......
......@@ -26,7 +26,10 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
dtype="float32")
self.inputs = {"Logits": logits, "Label": labels}
self.outputs = {"Softmax": softmax, "Loss": cross_entropy}
self.outputs = {
"Softmax": softmax.astype('float32'),
"Loss": cross_entropy.astype('float32')
}
def test_check_output(self):
self.check_output()
......@@ -56,7 +59,10 @@ class TestSoftmaxWithCrossEntropyOp2(OpTest):
axis=1, keepdims=True).astype("float32")
self.inputs = {"Logits": logits, "Label": labels}
self.outputs = {"Softmax": softmax, "Loss": cross_entropy}
self.outputs = {
"Softmax": softmax.astype('float32'),
"Loss": cross_entropy.astype('float32')
}
self.attrs = {"soft_label": True}
def test_check_output(self):
......@@ -67,4 +73,5 @@ class TestSoftmaxWithCrossEntropyOp2(OpTest):
if __name__ == "__main__":
exit(0) # FIXME: xe has bug
unittest.main()
......@@ -50,28 +50,18 @@ next_word = layers.data(
program=program,
init_program=init_program)
embed_param_attr_1 = {
'name': 'shared_w',
'init_attr': {
'max': 1.0,
'type': 'uniform_random',
'min': -1.0
}
}
embed_param_attr_2 = {'name': 'shared_w'}
embed_first = layers.embedding(
input=first_word,
size=[dict_size, embed_size],
data_type='float32',
param_attr=embed_param_attr_1,
param_attr={'name': 'shared_w'},
program=program,
init_program=init_program)
embed_second = layers.embedding(
input=second_word,
size=[dict_size, embed_size],
data_type='float32',
param_attr=embed_param_attr_2,
param_attr={'name': 'shared_w'},
program=program,
init_program=init_program)
......@@ -79,14 +69,14 @@ embed_third = layers.embedding(
input=third_word,
size=[dict_size, embed_size],
data_type='float32',
param_attr=embed_param_attr_2,
param_attr={'name': 'shared_w'},
program=program,
init_program=init_program)
embed_forth = layers.embedding(
input=forth_word,
size=[dict_size, embed_size],
data_type='float32',
param_attr=embed_param_attr_2,
param_attr={'name': 'shared_w'},
program=program,
init_program=init_program)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册