Commit 376c2f01 authored by phlrain

add default attr; test=develop

Parent 0f1e7e3d
@@ -208,7 +208,8 @@ Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc);
 class AttrReader {
  public:
-  explicit AttrReader(const AttributeMap& attrs) : attrs_(attrs) {}
+  explicit AttrReader(const AttributeMap& attrs, const AttributeMap& default_attrs = {})
+      : attrs_(attrs), default_attrs_(default_attrs) {}
   template <typename T>
   inline const T& Get(const std::string& name) const {
@@ -224,6 +225,7 @@ class AttrReader {
  private:
   const AttributeMap& attrs_;
+  const AttributeMap& default_attrs_;
 };
 // check whether a value(attribute) fit a certain limit
@@ -406,6 +408,14 @@ class OpAttrChecker {
     return default_values_map;
   }
+  void InitDefaultMap() {
+    for (const auto& checker : attr_checkers_) {
+      checker(&default_values_map_, true);
+    }
+  }
+  const AttributeMap& default_attr_map() const { return default_values_map_; }
   void RecordExplicitCheckerNum() {
     explicit_checker_num_ = attr_checkers_.size();
   }
@@ -413,6 +423,8 @@ class OpAttrChecker {
  private:
   std::vector<AttrChecker> attr_checkers_;
+  AttributeMap default_values_map_;
   // in order to improve the efficiency of dynamic graph mode,
   // we divede the attribute into explicit type and implicit type.
   // for explicit attribute, we mean the attribute added in the customized
......
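The hunk above elides the body of AttrReader::Get, so the diff does not show how the new default_attrs_ member is consulted. A minimal sketch of the presumable lookup order, written as a standalone helper rather than the commit's actual method (the name GetWithDefault and the error message are assumptions, not part of the commit):

// Sketch only: explicitly-set attrs win, checker-provided defaults are the fallback.
#include <string>
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/platform/enforce.h"

template <typename T>
const T& GetWithDefault(const paddle::framework::AttributeMap& attrs,
                        const paddle::framework::AttributeMap& default_attrs,
                        const std::string& name) {
  auto it = attrs.find(name);
  bool found = (it != attrs.end());
  if (!found) {
    it = default_attrs.find(name);  // fall back to the OpAttrChecker defaults
    found = (it != default_attrs.end());
  }
  PADDLE_ENFORCE_NE(found, false,
                    paddle::platform::errors::NotFound(
                        "can not find [%s] in attrs or default attrs", name));
  return BOOST_GET_CONST(T, it->second);
}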
@@ -194,11 +194,14 @@ void HogwildWorker::TrainFilesWithProfiler() {
 void HogwildWorker::TrainFiles() {
   platform::SetNumThreads(1);
+  std::cerr << "1!!!!!" << std::endl;
   // how to accumulate fetched values here
   device_reader_->Start();
   int cur_batch;
+  int i = 0;
   while ((cur_batch = device_reader_->Next()) > 0) {
+    i++;
     for (auto &op : ops_) {
       bool need_skip = false;
       for (auto t = 0u; t < skip_ops_.size(); ++t) {
@@ -215,6 +218,7 @@ void HogwildWorker::TrainFiles() {
     PrintFetchVars();
     thread_scope_->DropKids();
   }
+  std::cerr << "total bacth " << i << std::endl;
 #if defined PADDLE_WITH_PSCORE
   if (thread_barrier_) {
     paddle::distributed::Communicator::GetInstance()->BarrierTriggerDecrement();
......
@@ -124,6 +124,7 @@ Scope* MultiTrainer::GetWorkerScope(int thread_id) {
 void MultiTrainer::Run() {
   VLOG(3) << "Going to run";
+  LOG(ERROR) << "multi run " << thread_num_ << "\t" << debug_;
   for (int thidx = 0; thidx < thread_num_; ++thidx) {
     if (!debug_) {
       threads_.push_back(
......
@@ -66,6 +66,7 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
   op_checker_ = attr_checker;
   Make();
   op_checker_->RecordExplicitCheckerNum();
+  op_checker_->InitDefaultMap();
   AddAttr<int>(OpRoleAttrName(), "The role of this operator")
       .InEnum(
......
@@ -28,4 +28,6 @@ endif(NOT WIN32)
 cc_library(gradient_accumulator SRCS gradient_accumulator.cc DEPS blas operator lod_tensor selected_rows selected_rows_functor var_type_traits layer math_function)
+cc_binary(tracer_test SRCS tracer_test.cc DEPS tracer layer op_registry python pybind ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} profiler)
 add_subdirectory(tests)
@@ -35,11 +35,13 @@ class DygraphExecutionContext : public framework::ExecutionContext {
                           const framework::RuntimeContext& ctx,
                           const NameVarMap<VarType>& var_base_map_in,
                           const NameVarMap<VarType>& var_base_map_out,
-                          const framework::AttributeMap& attrs)
+                          const framework::AttributeMap& attrs,
+                          const framework::AttributeMap& default_attrs)
       : ExecutionContext(op, scope, device_context, ctx),
         var_base_map_in_(var_base_map_in),
         var_base_map_out_(var_base_map_out),
-        attrs_(attrs) {}
+        attrs_(attrs),
+        default_attrs_(default_attrs) {}
   std::string InputName(const std::string& name) const override {
     auto it = var_base_map_in_.find(name);
@@ -92,16 +94,22 @@ class DygraphExecutionContext : public framework::ExecutionContext {
   }
   bool HasAttr(const std::string& name) const override {
-    return attrs_.count(name) != 0;
+    return attrs_.count(name) != 0 || default_attrs_.count(name);
   }
   const framework::AttributeMap& Attrs() const override { return attrs_; }
   const framework::Attribute& GetAttr(const std::string& name) const override {
     auto it = attrs_.find(name);
+    bool find = (it != attrs_.end());
+    if (it == attrs_.end())
+    {
+      it = default_attrs_.find(name);
+      find = (it != default_attrs_.end());
+    }
     PADDLE_ENFORCE_NE(
-        it, attrs_.end(),
+        find, false,
         platform::errors::NotFound("can not find [%s] in attrs", name));
     return it->second;
@@ -192,6 +200,7 @@ class DygraphExecutionContext : public framework::ExecutionContext {
   const NameVarMap<VarType>& var_base_map_in_;
   const NameVarMap<VarType>& var_base_map_out_;
   const framework::AttributeMap& attrs_;
+  const framework::AttributeMap& default_attrs_;
 };
 }  // namespace imperative
......
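With the second map wired through the constructor, HasAttr and GetAttr consult the explicitly-set attributes first and fall back to the registered defaults. A rough usage sketch under assumed setup (op, scope, dev_ctx, ctx, ins and outs stand in for objects like the ones built in the layer test further down; the attribute names are examples only):

// Sketch: explicit attrs take precedence, defaults fill in the rest.
paddle::framework::AttributeMap attrs = {{"axis", 1}};
paddle::framework::AttributeMap default_attrs = {{"use_mkldnn", false}};

paddle::imperative::DygraphExecutionContext<paddle::imperative::VarBase>
    exe_ctx(*op, scope, *dev_ctx, ctx, ins, outs, attrs, default_attrs);

exe_ctx.HasAttr("axis");        // true: present in attrs_
exe_ctx.HasAttr("use_mkldnn");  // true: present only in default_attrs_
exe_ctx.GetAttr("use_mkldnn");  // resolved through the default_attrs_ fallback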
@@ -104,10 +104,11 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
   }
 #endif
+  // auto *attr_checker = op_->Info().Checker();
   // 1. get expected kernel key
   auto expected_kernel_key =
       op.GetExpectedKernelType(DygraphExecutionContext<VarType>(
-          op, framework::Scope(), *dev_ctx, ctx, ins, outs, attrs));
+          op, framework::Scope(), *dev_ctx, ctx, ins, outs, attrs, {}));
   VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
   // 2. check if op[type] has kernel registered.
@@ -172,8 +173,10 @@ static void PreparedOpRunImpl(
   static_cast<const framework::OperatorWithKernel&>(op).InferShape(
       &infer_shape_ctx);
+  auto *attr_checker = op.Info().Checker();
   func(DygraphExecutionContext<VarType>(op, scope, *dev_ctx, ctx, ins, outs,
-                                        attrs));
+                                        attrs,
+                                        attr_checker->default_attr_map()));
   /**
    * [ Why need handle complex gradient to real gradient? ]
......
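PreparedOpRunImpl above dereferences op.Info().Checker() unconditionally, whereas Tracer::TraceOp (further down) guards the same pointer with if (attr_checker). A defensive variant of that call, shown only as a sketch and not as the commit's code, could fall back to an empty map:

// Sketch: tolerate ops whose OpInfo registers no OpAttrChecker (Checker() == nullptr).
static const paddle::framework::AttributeMap kEmptyAttrMap;
const auto *attr_checker = op.Info().Checker();
const auto &default_attrs =
    attr_checker ? attr_checker->default_attr_map() : kEmptyAttrMap;
func(DygraphExecutionContext<VarType>(op, scope, *dev_ctx, ctx, ins, outs,
                                      attrs, default_attrs));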
@@ -358,7 +358,7 @@ TEST(test_layer, test_dygraph_execution_context) {
   framework::Scope scope;
   DygraphExecutionContext<imperative::VarBase> dy_exe_context(
-      *(op.get()), scope, *dev_ctx, ctx, ins, outs, concat_att_map);
+      *(op.get()), scope, *dev_ctx, ctx, ins, outs, concat_att_map, {});
   ASSERT_EQ(dy_exe_context.InputSize("X"), 1u);
   ASSERT_EQ(dy_exe_context.InputName("X"), "vin");
......
@@ -149,11 +149,16 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
     }
   }
   auto op = framework::OpRegistry::CreateOp(type, {}, {}, {}, false);
   const auto& op_info = op->Info();
   auto* attr_checker = op_info.Checker();
   if (attr_checker) {
     attr_checker->Check(&attrs, true);
   }
   NameVarBaseMap new_ins = ins;
   if (enable_autocast_) {
......
@@ -109,6 +109,7 @@ class ElementwiseOp : public framework::OperatorWithKernel {
     auto input_data_type =
         OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y");
+    /*
 #ifdef PADDLE_WITH_MKLDNN
     if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
       return framework::OpKernelType(input_data_type, ctx.GetPlace(),
@@ -116,6 +117,7 @@ class ElementwiseOp : public framework::OperatorWithKernel {
                                      framework::LibraryType::kMKLDNN);
     }
 #endif
+    */
     return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
......
@@ -186,12 +186,14 @@ void SyncBatchNormFunctor(const framework::ExecutionContext &ctx,
         framework::DataLayout::kNHWC><<<grid, threads, 0, stream>>>(
         x_d, N, H * W * D, C, stats);
   }
+  /*
   Tensor c_g_st;
   auto *c_g_st_d = c_g_st.mutable_data<BatchNormParamType<T>>(
       {2 * C + 1}, platform::CPUPlace());
   auto gplace = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace());
   memory::Copy(platform::CPUPlace(), c_g_st_d, gplace, stats, bytes, 0);
+  */
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
   auto *comm = dev_ctx.nccl_comm();
......
@@ -177,4 +177,4 @@ static inline void HandleViewBetweenInputAndOutput(
 }  // namespace paddle
 // This include must be the last line
-#include "paddle/fluid/pybind/op_function_impl.h"
+#include "paddle/fluid/pybind/op_function_impl_new.h"