Commit 376c2f01 authored by P phlrain

add default attr; test=develop

Parent 0f1e7e3d
......@@ -208,7 +208,8 @@ Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc);
class AttrReader {
public:
explicit AttrReader(const AttributeMap& attrs) : attrs_(attrs) {}
explicit AttrReader(const AttributeMap& attrs, const AttributeMap& default_attrs = {})
: attrs_(attrs), default_attrs_(default_attrs) {}
template <typename T>
inline const T& Get(const std::string& name) const {
......@@ -224,6 +225,7 @@ class AttrReader {
private:
const AttributeMap& attrs_;
const AttributeMap& default_attrs_;
};
// check whether a value (attribute) fits a certain limit
......@@ -406,6 +408,14 @@ class OpAttrChecker {
return default_values_map;
}
void InitDefaultMap() {
for (const auto& checker : attr_checkers_) {
checker(&default_values_map_, true);
}
}
const AttributeMap& default_attr_map() const { return default_values_map_; }
void RecordExplicitCheckerNum() {
explicit_checker_num_ = attr_checkers_.size();
}
......@@ -413,6 +423,8 @@ class OpAttrChecker {
private:
std::vector<AttrChecker> attr_checkers_;
AttributeMap default_values_map_;
// in order to improve the efficiency of dynamic graph mode,
// we divide the attributes into explicit and implicit types.
// for explicit attributes, we mean the attributes added in the customized
......
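The new InitDefaultMap() runs every registered AttrChecker once and records the default value each one would supply, so the dygraph path can reuse default_values_map_ instead of re-running the checkers on every op call. Below is a minimal self-contained sketch of that build-once pattern, with std::map standing in for framework::AttributeMap and the bool flag assumed to mean "fill in your default value" (the checker signature itself is not shown in this hunk; all names here are illustrative):

#include <functional>
#include <map>
#include <string>
#include <vector>

using AttributeMap = std::map<std::string, int>;                // stand-in type
using AttrChecker = std::function<void(AttributeMap*, bool)>;

struct MiniAttrChecker {                                        // illustrative only
  std::vector<AttrChecker> attr_checkers_;
  AttributeMap default_values_map_;

  // Each SetDefault-style checker writes its default when asked to.
  void AddDefault(const std::string& name, int value) {
    attr_checkers_.push_back([name, value](AttributeMap* m, bool fill) {
      if (fill && m->count(name) == 0) (*m)[name] = value;
    });
  }

  // Mirrors InitDefaultMap() above: materialize all defaults exactly once.
  void InitDefaultMap() {
    for (const auto& checker : attr_checkers_) {
      checker(&default_values_map_, true);
    }
  }

  const AttributeMap& default_attr_map() const { return default_values_map_; }
};

int main() {
  MiniAttrChecker checker;
  checker.AddDefault("axis", -1);
  checker.AddDefault("use_mkldnn", 0);
  checker.InitDefaultMap();
  return checker.default_attr_map().size() == 2 ? 0 : 1;        // both defaults recorded
}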
......@@ -195,10 +195,13 @@ void HogwildWorker::TrainFilesWithProfiler() {
void HogwildWorker::TrainFiles() {
platform::SetNumThreads(1);
std::cerr << "1!!!!!" << std::endl;
// how to accumulate fetched values here
device_reader_->Start();
int cur_batch;
int i = 0;
while ((cur_batch = device_reader_->Next()) > 0) {
i++;
for (auto &op : ops_) {
bool need_skip = false;
for (auto t = 0u; t < skip_ops_.size(); ++t) {
......@@ -215,6 +218,7 @@ void HogwildWorker::TrainFiles() {
PrintFetchVars();
thread_scope_->DropKids();
}
std::cerr << "total bacth " << i << std::endl;
#if defined PADDLE_WITH_PSCORE
if (thread_barrier_) {
paddle::distributed::Communicator::GetInstance()->BarrierTriggerDecrement();
......
......@@ -124,6 +124,7 @@ Scope* MultiTrainer::GetWorkerScope(int thread_id) {
void MultiTrainer::Run() {
VLOG(3) << "Going to run";
LOG(ERROR) << "multi run " << thread_num_ << "\t" << debug_;
for (int thidx = 0; thidx < thread_num_; ++thidx) {
if (!debug_) {
threads_.push_back(
......
......@@ -66,6 +66,7 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
op_checker_ = attr_checker;
Make();
op_checker_->RecordExplicitCheckerNum();
op_checker_->InitDefaultMap();
AddAttr<int>(OpRoleAttrName(), "The role of this operator")
.InEnum(
......
......@@ -28,4 +28,6 @@ endif(NOT WIN32)
cc_library(gradient_accumulator SRCS gradient_accumulator.cc DEPS blas operator lod_tensor selected_rows selected_rows_functor var_type_traits layer math_function)
cc_binary(tracer_test SRCS tracer_test.cc DEPS tracer layer op_registry python pybind ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} profiler)
add_subdirectory(tests)
......@@ -35,11 +35,13 @@ class DygraphExecutionContext : public framework::ExecutionContext {
const framework::RuntimeContext& ctx,
const NameVarMap<VarType>& var_base_map_in,
const NameVarMap<VarType>& var_base_map_out,
const framework::AttributeMap& attrs)
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs)
: ExecutionContext(op, scope, device_context, ctx),
var_base_map_in_(var_base_map_in),
var_base_map_out_(var_base_map_out),
attrs_(attrs) {}
attrs_(attrs),
default_attrs_(default_attrs) {}
std::string InputName(const std::string& name) const override {
auto it = var_base_map_in_.find(name);
......@@ -92,7 +94,7 @@ class DygraphExecutionContext : public framework::ExecutionContext {
}
bool HasAttr(const std::string& name) const override {
return attrs_.count(name) != 0;
return attrs_.count(name) != 0 || default_attrs_.count(name) != 0;
}
const framework::AttributeMap& Attrs() const override { return attrs_; }
......@@ -100,8 +102,14 @@ class DygraphExecutionContext : public framework::ExecutionContext {
const framework::Attribute& GetAttr(const std::string& name) const override {
auto it = attrs_.find(name);
bool find = (it != attrs_.end());
if (it == attrs_.end()) {
  it = default_attrs_.find(name);
  find = (it != default_attrs_.end());
}
PADDLE_ENFORCE_NE(
it, attrs_.end(),
find, false,
platform::errors::NotFound("can not find [%s] in attrs", name));
return it->second;
......@@ -192,6 +200,7 @@ class DygraphExecutionContext : public framework::ExecutionContext {
const NameVarMap<VarType>& var_base_map_in_;
const NameVarMap<VarType>& var_base_map_out_;
const framework::AttributeMap& attrs_;
const framework::AttributeMap& default_attrs_;
};
} // namespace imperative
......
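HasAttr() and GetAttr() above now consult both maps: the explicitly passed attrs_ first, then default_attrs_ as a fallback. A small stand-alone model of that lookup order, with std::map in place of framework::AttributeMap and an exception standing in for PADDLE_ENFORCE_NE(find, false, ...); the names are illustrative:

#include <map>
#include <stdexcept>
#include <string>

using AttributeMap = std::map<std::string, int>;  // stand-in type

struct TwoMapAttrs {                              // illustrative only
  const AttributeMap& attrs;                      // explicitly set attributes
  const AttributeMap& defaults;                   // checker-provided defaults

  bool HasAttr(const std::string& name) const {
    return attrs.count(name) != 0 || defaults.count(name) != 0;
  }

  const int& GetAttr(const std::string& name) const {
    auto it = attrs.find(name);                   // an explicit value wins
    if (it == attrs.end()) {
      it = defaults.find(name);                   // otherwise fall back
      if (it == defaults.end())
        throw std::out_of_range("can not find [" + name + "] in attrs");
    }
    return it->second;
  }
};

int main() {
  AttributeMap attrs{{"axis", 1}};
  AttributeMap defaults{{"axis", -1}, {"use_mkldnn", 0}};
  TwoMapAttrs ctx{attrs, defaults};
  // axis comes from attrs (1); use_mkldnn only exists in defaults (0).
  return (ctx.GetAttr("axis") == 1 && ctx.HasAttr("use_mkldnn")) ? 0 : 1;
}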
......@@ -104,10 +104,11 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
}
#endif
// auto *attr_checker = op_->Info().Checker();
// 1. get expected kernel key
auto expected_kernel_key =
op.GetExpectedKernelType(DygraphExecutionContext<VarType>(
op, framework::Scope(), *dev_ctx, ctx, ins, outs, attrs));
op, framework::Scope(), *dev_ctx, ctx, ins, outs, attrs, {}));
VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
// 2. check if op[type] has kernel registered.
......@@ -172,8 +173,10 @@ static void PreparedOpRunImpl(
static_cast<const framework::OperatorWithKernel&>(op).InferShape(
&infer_shape_ctx);
auto *attr_checker = op.Info().Checker();
func(DygraphExecutionContext<VarType>(op, scope, *dev_ctx, ctx, ins, outs,
attrs));
attrs,
attr_checker->default_attr_map()));
/**
* [ Why need handle complex gradient to real gradient? ]
......
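Note that PreparedOpRunImpl above dereferences op.Info().Checker() unconditionally, while the Tracer::TraceOp hunk below guards the same pointer with if (attr_checker). A sketch of the guarded lookup, using a stand-in Checker type (illustrative names, not the framework's API):

#include <map>
#include <string>

using AttributeMap = std::map<std::string, int>;  // stand-in type

struct Checker {                                  // stand-in for OpAttrChecker
  AttributeMap defaults;
  const AttributeMap& default_attr_map() const { return defaults; }
};

// Fall back to a shared empty map when an op has no attribute checker,
// instead of dereferencing a null pointer.
const AttributeMap& DefaultsOrEmpty(const Checker* attr_checker) {
  static const AttributeMap kEmpty;
  return attr_checker ? attr_checker->default_attr_map() : kEmpty;
}

int main() {
  Checker c{{{"axis", -1}}};
  bool ok = DefaultsOrEmpty(&c).size() == 1 &&    // defaults come from the checker
            DefaultsOrEmpty(nullptr).empty();     // empty map for checker-less ops
  return ok ? 0 : 1;
}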
......@@ -358,7 +358,7 @@ TEST(test_layer, test_dygraph_execution_context) {
framework::Scope scope;
DygraphExecutionContext<imperative::VarBase> dy_exe_context(
*(op.get()), scope, *dev_ctx, ctx, ins, outs, concat_att_map);
*(op.get()), scope, *dev_ctx, ctx, ins, outs, concat_att_map, {});
ASSERT_EQ(dy_exe_context.InputSize("X"), 1u);
ASSERT_EQ(dy_exe_context.InputName("X"), "vin");
......
......@@ -149,12 +149,17 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
}
}
auto op = framework::OpRegistry::CreateOp(type, {}, {}, {}, false);
const auto& op_info = op->Info();
auto* attr_checker = op_info.Checker();
if (attr_checker) {
attr_checker->Check(&attrs, true);
}
NameVarBaseMap new_ins = ins;
if (enable_autocast_) {
VLOG(5) << "Auto mixed precision run operator: " << type;
......
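In Tracer::TraceOp the checker is now applied eagerly to the user-supplied attrs via attr_checker->Check(&attrs, true); presumably this validates the explicit values and completes the map from the defaults, though the exact meaning of the bool flag is defined by the checker implementation rather than this hunk. A rough stand-alone model of such a fill-from-defaults pass (illustrative names only):

#include <map>
#include <string>

using AttributeMap = std::map<std::string, int>;  // stand-in type

// Complete a user-supplied attribute map from the defaults, leaving explicitly
// set values untouched. The real Check() additionally enforces per-attribute
// constraints (enums, ranges, required attributes).
void FillFromDefaults(AttributeMap* attrs, const AttributeMap& defaults) {
  for (const auto& kv : defaults) {
    if (attrs->count(kv.first) == 0) attrs->emplace(kv.first, kv.second);
  }
}

int main() {
  AttributeMap attrs{{"axis", 1}};
  AttributeMap defaults{{"axis", -1}, {"use_mkldnn", 0}};
  FillFromDefaults(&attrs, defaults);
  // axis keeps its explicit value; use_mkldnn is filled from its default.
  return (attrs.at("axis") == 1 && attrs.at("use_mkldnn") == 0) ? 0 : 1;
}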
......@@ -109,6 +109,7 @@ class ElementwiseOp : public framework::OperatorWithKernel {
auto input_data_type =
OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y");
/*
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
......@@ -116,6 +117,7 @@ class ElementwiseOp : public framework::OperatorWithKernel {
framework::LibraryType::kMKLDNN);
}
#endif
*/
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
......
......@@ -187,11 +187,13 @@ void SyncBatchNormFunctor(const framework::ExecutionContext &ctx,
x_d, N, H * W * D, C, stats);
}
/*
Tensor c_g_st;
auto *c_g_st_d = c_g_st.mutable_data<BatchNormParamType<T>>(
{2 * C + 1}, platform::CPUPlace());
auto gplace = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace());
memory::Copy(platform::CPUPlace(), c_g_st_d, gplace, stats, bytes, 0);
*/
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
auto *comm = dev_ctx.nccl_comm();
......
......@@ -177,4 +177,4 @@ static inline void HandleViewBetweenInputAndOutput(
} // namespace paddle
// This include must be the last line
#include "paddle/fluid/pybind/op_function_impl.h"
#include "paddle/fluid/pybind/op_function_impl_new.h"