提交 61ebacbc 编写于 作者: Q Qiao Longfei 提交者: GitHub

use operator context and infer context (#3024)

* use operator context

* optimize code

* update net infershape

* update InferShape

* disable override InferShape(scope) in OperatorBase

* change InferShapeImpl to InferShape

* add template to OperatorContext Input/Output

* merge Input InputVar, Output OutputVar

* change Inputs to MultiInput

* fix conflict

* fix MultiInput bugs and add unit test

* rename KernelContext to ExecutionContext

* clean code

* change InferShape to protected

* fix template bug

* refine code

* use InputVar instead of Input<Variable>

* typo

* optimize code
上级 0b680772
...@@ -16,8 +16,7 @@ static int run_cnt = 0; ...@@ -16,8 +16,7 @@ static int run_cnt = 0;
class TestOp : public OperatorBase { class TestOp : public OperatorBase {
public: public:
void InferShape( void InferShape(const std::shared_ptr<Scope>& scope) const override {
const std::shared_ptr<framework::Scope>& scope) const override {
++infer_shape_cnt; ++infer_shape_cnt;
} }
void Run(const std::shared_ptr<framework::Scope>& scope, void Run(const std::shared_ptr<framework::Scope>& scope,
......
...@@ -20,7 +20,7 @@ namespace paddle { ...@@ -20,7 +20,7 @@ namespace paddle {
namespace framework { namespace framework {
template <> template <>
Eigen::DefaultDevice* KernelContext::GetEigenDevice< Eigen::DefaultDevice* ExecutionContext::GetEigenDevice<
platform::CPUPlace, Eigen::DefaultDevice>() const { platform::CPUPlace, Eigen::DefaultDevice>() const {
return device_context_.get_eigen_device<Eigen::DefaultDevice>(); return device_context_.get_eigen_device<Eigen::DefaultDevice>();
} }
...@@ -28,7 +28,7 @@ Eigen::DefaultDevice* KernelContext::GetEigenDevice< ...@@ -28,7 +28,7 @@ Eigen::DefaultDevice* KernelContext::GetEigenDevice<
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
template <> template <>
Eigen::GpuDevice* Eigen::GpuDevice*
KernelContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const { ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
return device_context_.get_eigen_device<Eigen::GpuDevice>(); return device_context_.get_eigen_device<Eigen::GpuDevice>();
} }
#endif #endif
......
...@@ -31,22 +31,9 @@ limitations under the License. */ ...@@ -31,22 +31,9 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
template <typename T>
struct EigenDeviceConverter;
template <>
struct EigenDeviceConverter<platform::CPUPlace> {
using EigenDeviceType = Eigen::DefaultDevice;
};
#ifndef PADDLE_ONLY_CPU
template <>
struct EigenDeviceConverter<platform::GPUPlace> {
using EigenDeviceType = Eigen::GpuDevice;
};
#endif
class OperatorBase; class OperatorBase;
class InferShapeContext;
class ExecutionContext;
/** /**
* OperatorBase has the basic element that Net will call to do computation. * OperatorBase has the basic element that Net will call to do computation.
* Only CreateOperator from OpRegistry will new Operator directly. User * Only CreateOperator from OpRegistry will new Operator directly. User
...@@ -112,46 +99,127 @@ class OperatorBase { ...@@ -112,46 +99,127 @@ class OperatorBase {
std::shared_ptr<std::unordered_map<std::string, int>> in_out_idxs_; std::shared_ptr<std::unordered_map<std::string, int>> in_out_idxs_;
}; };
class KernelContext { class OperatorContext {
public: public:
KernelContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope, OperatorContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope)
const platform::DeviceContext& device_context) : op_(*op), scope_(scope) {}
: op_(*op), scope_(scope), device_context_(device_context) {}
size_t InputSize() const { return op_.inputs_.size(); }
const Variable* Input(int index) const { size_t OutputSize() const { return op_.outputs_.size(); }
return scope_->GetVariable(op_.inputs_[index]);
const Variable* InputVar(const size_t& index) const {
return scope_->GetVariable(op_.inputs_.at(index));
} }
Variable* Output(int index) const { Variable* OutputVar(const size_t& index) const {
return scope_->GetVariable(op_.outputs_[index]); return scope_->GetVariable(op_.outputs_.at(index));
} }
const Variable* Input(const std::string& name) const { const Variable* InputVar(const std::string& name) const {
return scope_->GetVariable(op_.Input(name)); return scope_->GetVariable(op_.Input(name));
} }
const Variable* Output(const std::string& name) const { Variable* OutputVar(const std::string& name) const {
return scope_->GetVariable(op_.Output(name)); return scope_->GetVariable(op_.Output(name));
} }
const std::vector<const Variable*> Inputs(const std::string& name) const { const std::vector<const Variable*> MultiInputVar(
const std::string& name) const {
auto names = op_.Inputs(name); auto names = op_.Inputs(name);
std::vector<const Variable*> res; std::vector<const Variable*> res;
res.reserve(names.size());
std::transform( std::transform(
names.begin(), names.end(), res.begin(), names.begin(), names.end(), std::back_inserter(res),
[this](const std::string& name) { return scope_->GetVariable(name); }); [this](const std::string& name) { return scope_->GetVariable(name); });
return res; return res;
} }
const std::vector<const Variable*> Outputs(const std::string& name) const { std::vector<const Variable*> MultiOutputVar(const std::string& name) const {
auto names = op_.Outputs(name); auto names = op_.Outputs(name);
std::vector<const Variable*> res; std::vector<const Variable*> res;
res.reserve(names.size());
std::transform( std::transform(
names.begin(), names.end(), res.begin(), names.begin(), names.end(), std::back_inserter(res),
[this](const std::string& name) { return scope_->GetVariable(name); }); [this](const std::string& name) { return scope_->GetVariable(name); });
return res; return res;
} }
template <typename T>
const T* Input(const size_t& index) const {
return &(InputVar(index)->Get<T>());
}
template <typename T>
T* Output(const size_t& index) const {
return OutputVar(index)->GetMutable<T>();
}
template <typename T>
const T* Input(const std::string& name) const {
return &(InputVar(name)->Get<T>());
}
template <typename T>
T* Output(const std::string& name) const {
return OutputVar(name)->GetMutable<T>();
}
template <typename T>
const std::vector<const T*> MultiInput(const std::string& name) const {
auto names = op_.Inputs(name);
std::vector<const T*> res;
res.reserve(names.size());
std::transform(names.begin(), names.end(), std::back_inserter(res),
[this](const std::string& name) {
return &scope_->GetVariable(name)->Get<T>();
});
return res;
}
template <typename T>
std::vector<const T*> MultiOutput(const std::string& name) const {
auto names = op_.Outputs(name);
std::vector<const T*> res;
res.reserve(names.size());
std::transform(names.begin(), names.end(), std::back_inserter(res),
[this](const std::string& name) {
return scope_->GetVariable(name)->GetMutable<T>();
});
return res;
}
const OperatorBase& op_;
const std::shared_ptr<Scope>& scope_;
};
class InferShapeContext : public OperatorContext {
public:
InferShapeContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope)
: OperatorContext(op, scope) {}
};
template <typename T>
struct EigenDeviceConverter;
template <>
struct EigenDeviceConverter<platform::CPUPlace> {
using EigenDeviceType = Eigen::DefaultDevice;
};
#ifndef PADDLE_ONLY_CPU
template <>
struct EigenDeviceConverter<platform::GPUPlace> {
using EigenDeviceType = Eigen::GpuDevice;
};
#endif
class ExecutionContext : public OperatorContext {
public:
ExecutionContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& device_context)
: OperatorContext(op, scope), device_context_(device_context) {}
template <typename PlaceType, template <typename PlaceType,
typename DeviceType = typename DeviceType =
typename EigenDeviceConverter<PlaceType>::EigenDeviceType> typename EigenDeviceConverter<PlaceType>::EigenDeviceType>
...@@ -159,38 +227,23 @@ class KernelContext { ...@@ -159,38 +227,23 @@ class KernelContext {
platform::Place GetPlace() const { return device_context_.GetPlace(); } platform::Place GetPlace() const { return device_context_.GetPlace(); }
const OperatorBase& op_;
const std::shared_ptr<Scope>& scope_;
const platform::DeviceContext& device_context_; const platform::DeviceContext& device_context_;
}; };
class OpKernel { class OpKernel {
public: public:
/** /**
* KernelContext is the only parameter of Kernel Run function. * ExecutionContext is the only parameter of Kernel Run function.
* Run will get input/output variables, state such as momentum and * Run will get input/output variables, state such as momentum and
* device resource such as CUDA stream, cublas handle, etc. from * device resource such as CUDA stream, cublas handle, etc. from
* KernelContext. User should construct it before run the Operator. * ExecutionContext. User should construct it before run the Operator.
*/ */
virtual void Compute(const KernelContext& context) const = 0; virtual void Compute(const ExecutionContext& context) const = 0;
virtual ~OpKernel() {} virtual ~OpKernel() {}
}; };
template <typename T>
struct VarToTensor {};
template <>
struct VarToTensor<Tensor*> {
Tensor* operator()(Variable* var) { return var->GetMutable<Tensor>(); }
};
template <>
struct VarToTensor<const Tensor*> {
const Tensor* operator()(Variable* var) { return &var->Get<Tensor>(); }
};
class OperatorWithKernel : public OperatorBase { class OperatorWithKernel : public OperatorBase {
public: public:
struct OpKernelKey { struct OpKernelKey {
...@@ -216,10 +269,14 @@ class OperatorWithKernel : public OperatorBase { ...@@ -216,10 +269,14 @@ class OperatorWithKernel : public OperatorBase {
using OpKernelMap = using OpKernelMap =
std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>; std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;
void InferShape(const std::shared_ptr<Scope>& scope) const {
InferShape(InferShapeContext(this, scope));
}
void Run(const std::shared_ptr<Scope>& scope, void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const final { const platform::DeviceContext& dev_ctx) const final {
auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx)); auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx));
opKernel->Compute(KernelContext(this, scope, dev_ctx)); opKernel->Compute(ExecutionContext(this, scope, dev_ctx));
} }
static std::unordered_map<std::string /* op_type */, OpKernelMap>& static std::unordered_map<std::string /* op_type */, OpKernelMap>&
...@@ -228,34 +285,8 @@ class OperatorWithKernel : public OperatorBase { ...@@ -228,34 +285,8 @@ class OperatorWithKernel : public OperatorBase {
return g_all_op_kernels; return g_all_op_kernels;
} }
void InferShape(const std::shared_ptr<Scope>& scope) const final {
std::vector<const Tensor*> ins;
VarNamesToTensors(scope, inputs_, &ins);
std::vector<Tensor*> outs;
VarNamesToTensors(scope, outputs_, &outs);
InferShape(ins, outs);
};
private:
template <typename T>
void VarNamesToTensors(const std::shared_ptr<Scope>& scope,
const std::vector<std::string>& var_names,
std::vector<T>* container) const {
container->reserve(var_names.size());
VarToTensor<T> convert;
for (auto& name : var_names) {
auto var = scope->GetVariable(name);
if (var != nullptr) {
container->push_back(convert(var));
} else {
container->push_back(nullptr);
}
}
}
protected: protected:
virtual void InferShape(const std::vector<const Tensor*>& inputs, virtual void InferShape(const InferShapeContext& ctx) const = 0;
const std::vector<Tensor*>& outputs) const = 0;
}; };
} // namespace framework } // namespace framework
......
...@@ -24,7 +24,8 @@ static int op_run_num = 0; ...@@ -24,7 +24,8 @@ static int op_run_num = 0;
class OpWithoutKernelTest : public OperatorBase { class OpWithoutKernelTest : public OperatorBase {
public: public:
void Init() override { x = 1; } void Init() override { x = 1; }
void InferShape(const std::shared_ptr<Scope>& scope) const override {} void InferShape(
const std::shared_ptr<framework::Scope>& scope) const override {}
void Run(const std::shared_ptr<Scope>& scope, void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override { const platform::DeviceContext& dev_ctx) const override {
op_run_num++; op_run_num++;
...@@ -73,6 +74,7 @@ TEST(OperatorBase, all) { ...@@ -73,6 +74,7 @@ TEST(OperatorBase, all) {
auto op = paddle::framework::OpRegistry::CreateOp(op_desc); auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
scope->CreateVariable("OUT1"); scope->CreateVariable("OUT1");
ASSERT_EQ(paddle::framework::op_run_num, 0); ASSERT_EQ(paddle::framework::op_run_num, 0);
op->InferShape(scope);
op->Run(scope, device_context); op->Run(scope, device_context);
ASSERT_EQ(paddle::framework::op_run_num, 1); ASSERT_EQ(paddle::framework::op_run_num, 1);
} }
...@@ -97,14 +99,13 @@ static int cpu_kernel_run_num = 0; ...@@ -97,14 +99,13 @@ static int cpu_kernel_run_num = 0;
class OpWithKernelTest : public OperatorWithKernel { class OpWithKernelTest : public OperatorWithKernel {
protected: protected:
void InferShape(const std::vector<const Tensor*>& inputs, void InferShape(const framework::InferShapeContext& ctx) const override {}
const std::vector<Tensor*>& outputs) const override {}
}; };
template <typename T1, typename T2> template <typename T1, typename T2>
class CPUKernelTest : public OpKernel { class CPUKernelTest : public OpKernel {
public: public:
void Compute(const KernelContext& ctx) const { void Compute(const ExecutionContext& ctx) const {
std::cout << "this is cpu kernel" << std::endl; std::cout << "this is cpu kernel" << std::endl;
std::cout << ctx.op_.DebugString() << std::endl; std::cout << ctx.op_.DebugString() << std::endl;
cpu_kernel_run_num++; cpu_kernel_run_num++;
...@@ -117,7 +118,8 @@ class CPUKernelTest : public OpKernel { ...@@ -117,7 +118,8 @@ class CPUKernelTest : public OpKernel {
class OperatorMultiInputsTest : public OperatorBase { class OperatorMultiInputsTest : public OperatorBase {
public: public:
void Init() override { x = 1; } void Init() override { x = 1; }
void InferShape(const std::shared_ptr<Scope>& scope) const override {} void InferShape(
const std::shared_ptr<framework::Scope>& scope) const override {}
void Run(const std::shared_ptr<Scope>& scope, void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override { const platform::DeviceContext& dev_ctx) const override {
ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr); ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr);
...@@ -149,13 +151,31 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker ...@@ -149,13 +151,31 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker
class CPUKernalMultiInputsTest : public OpKernel { class CPUKernalMultiInputsTest : public OpKernel {
public: public:
void Compute(const KernelContext& ctx) const { void Compute(const ExecutionContext& ctx) const {
auto xs = ctx.op_.Inputs("xs"); auto xs = ctx.op_.Inputs("xs");
ASSERT_EQ(xs.size(), 3UL); ASSERT_EQ(xs.size(), 3UL);
ASSERT_EQ(xs[0], "x0"); ASSERT_EQ(xs[0], "x0");
ASSERT_EQ(xs[1], "x1"); ASSERT_EQ(xs[1], "x1");
ASSERT_EQ(xs[2], "x2"); ASSERT_EQ(xs[2], "x2");
auto inVar0 = ctx.MultiInputVar("xs");
ASSERT_EQ(inVar0.size(), 3);
auto intVar1 = ctx.InputVar("k");
ASSERT_NE(intVar1, nullptr);
auto outVar0 = ctx.MultiOutputVar("ys");
ASSERT_EQ(outVar0.size(), 2);
auto inTensor0 = ctx.MultiInput<Tensor>("xs");
ASSERT_EQ(inTensor0.size(), 3);
auto intTensor1 = ctx.Input<Tensor>("k");
ASSERT_NE(intTensor1, nullptr);
auto outTensor0 = ctx.MultiOutput<Tensor>("ys");
ASSERT_EQ(outTensor0.size(), 2);
auto k = ctx.op_.Input("k"); auto k = ctx.op_.Input("k");
ASSERT_EQ(k, "k0"); ASSERT_EQ(k, "k0");
...@@ -233,6 +253,12 @@ TEST(OpKernel, multi_inputs) { ...@@ -233,6 +253,12 @@ TEST(OpKernel, multi_inputs) {
paddle::platform::CPUDeviceContext cpu_device_context; paddle::platform::CPUDeviceContext cpu_device_context;
auto scope = std::make_shared<Scope>(); auto scope = std::make_shared<Scope>();
scope->CreateVariable("x0")->GetMutable<Tensor>();
scope->CreateVariable("x1")->GetMutable<Tensor>();
scope->CreateVariable("x2")->GetMutable<Tensor>();
scope->CreateVariable("k0")->GetMutable<Tensor>();
scope->CreateVariable("y0")->GetMutable<Tensor>();
scope->CreateVariable("y1")->GetMutable<Tensor>();
auto op = paddle::framework::OpRegistry::CreateOp(op_desc); auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
op->Run(scope, cpu_device_context); op->Run(scope, cpu_device_context);
......
...@@ -19,16 +19,16 @@ namespace operators { ...@@ -19,16 +19,16 @@ namespace operators {
class AddOp : public OperatorWithKernel { class AddOp : public OperatorWithKernel {
protected: protected:
void InferShape(const std::vector<const Tensor *> &inputs, void InferShape(const InferShapeContext &ctx) const override {
const std::vector<Tensor *> &outputs) const override { PADDLE_ENFORCE(ctx.InputSize() == 2, "Input size of AddOp must be two");
PADDLE_ENFORCE(inputs.size() == 2, "Input size of AddOp must be two"); PADDLE_ENFORCE(ctx.OutputSize() == 1, "Output size of AddOp must be one");
PADDLE_ENFORCE(outputs.size() == 1, "Output size of AddOp must be one"); PADDLE_ENFORCE(ctx.InputVar(0) != nullptr && ctx.InputVar(1) != nullptr,
PADDLE_ENFORCE( "Inputs of AddOp must all be set");
inputs[0] != nullptr && inputs[1] != nullptr && outputs[0] != nullptr, PADDLE_ENFORCE(ctx.OutputVar(0) != nullptr,
"Inputs/Outputs of AddOp must all be set"); "Outputs of AddOp must all be set");
PADDLE_ENFORCE(inputs[0]->dims() == inputs[1]->dims(), PADDLE_ENFORCE(ctx.Input<Tensor>(0)->dims() == ctx.Input<Tensor>(1)->dims(),
"Two input of Add Op's dimension must be same."); "Two input of Add Op's dimension must be same.");
outputs[0]->Resize(inputs[0]->dims()); ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
} }
}; };
...@@ -49,8 +49,7 @@ The equation is: Out = X + Y ...@@ -49,8 +49,7 @@ The equation is: Out = X + Y
class AddOpGrad : public OperatorWithKernel { class AddOpGrad : public OperatorWithKernel {
protected: protected:
void InferShape(const std::vector<const Tensor *> &inputs, void InferShape(const InferShapeContext &ctx) const override {}
const std::vector<Tensor *> &outputs) const override {}
std::string DebugString() const override { std::string DebugString() const override {
LOG(INFO) << "AddOpGrad"; LOG(INFO) << "AddOpGrad";
return ""; return "";
......
...@@ -21,16 +21,17 @@ namespace operators { ...@@ -21,16 +21,17 @@ namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class AddKernel : public OpKernel { class AddKernel : public OpKernel {
public: public:
void Compute(const KernelContext& context) const override { void Compute(const ExecutionContext& context) const override {
auto input0 = context.Input(0)->Get<Tensor>(); auto input0 = context.Input<Tensor>(0);
auto input1 = context.Input(1)->Get<Tensor>(); auto input1 = context.Input<Tensor>(1);
auto output = context.Output(0)->GetMutable<Tensor>(); auto output = context.Output<Tensor>(0);
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
EigenVector<T>::Flatten(*output).device( EigenVector<T>::Flatten(*output).device(
*(context.GetEigenDevice<Place>())) = *(context.GetEigenDevice<Place>())) =
EigenVector<T>::Flatten(input0) + EigenVector<T>::Flatten(input1); framework::EigenVector<T>::Flatten(*input0) +
framework::EigenVector<T>::Flatten(*input1);
} }
}; };
......
...@@ -19,20 +19,20 @@ namespace operators { ...@@ -19,20 +19,20 @@ namespace operators {
class OnehotCrossEntropyOp : public OperatorWithKernel { class OnehotCrossEntropyOp : public OperatorWithKernel {
protected: protected:
void InferShape(const std::vector<const Tensor *> &inputs, void InferShape(const InferShapeContext &ctx) const override {
const std::vector<Tensor *> &outputs) const override { PADDLE_ENFORCE(ctx.InputSize() == 2,
PADDLE_ENFORCE(inputs.size() == 2,
"Input size of OnehotCrossEntropyOp must be two"); "Input size of OnehotCrossEntropyOp must be two");
PADDLE_ENFORCE(outputs.size() == 1, PADDLE_ENFORCE(ctx.OutputSize() == 1,
"Output size of OnehotCrossEntropyOp must be one"); "Output size of OnehotCrossEntropyOp must be one");
PADDLE_ENFORCE(inputs[0] != nullptr && inputs[1] != nullptr, PADDLE_ENFORCE(ctx.InputVar(0) != nullptr && ctx.InputVar(1) != nullptr,
"Inputs of OnehotCrossEntropyOp must all be set"); "Inputs of OnehotCrossEntropyOp must all be set");
PADDLE_ENFORCE(outputs[0] != nullptr, PADDLE_ENFORCE(ctx.OutputVar(0) != nullptr,
"Outputs of OnehotCrossEntropyOp must all be set"); "Outputs of OnehotCrossEntropyOp must all be set");
PADDLE_ENFORCE(inputs[0]->dims().size() == 2, "X's dimension must be 2."); PADDLE_ENFORCE(ctx.Input<Tensor>(0)->dims().size() == 2,
PADDLE_ENFORCE(outputs[0]->dims().size() == 1, "X's dimension must be 2.");
PADDLE_ENFORCE(ctx.Output<Tensor>(0)->dims().size() == 1,
"label's dimension must be 1."); "label's dimension must be 1.");
outputs[0]->Resize({inputs[0]->dims()[0]}); ctx.Output<Tensor>(0)->Resize({ctx.Input<Tensor>(0)->dims()[0]});
} }
}; };
......
...@@ -23,18 +23,18 @@ class OnehotCrossEntropyOpKernel : public OpKernel { ...@@ -23,18 +23,18 @@ class OnehotCrossEntropyOpKernel : public OpKernel {
public: public:
constexpr T LOG_THRESHOLD() const { return static_cast<T>(1e-20); } constexpr T LOG_THRESHOLD() const { return static_cast<T>(1e-20); }
void Compute(const KernelContext& context) const override { void Compute(const ExecutionContext& ctx) const override {
auto X = context.Input(0)->Get<Tensor>(); auto X = ctx.Input<Tensor>(0);
const T* X_data = X.data<T>(); const T* X_data = X->data<T>();
const int* label_data = context.Input(1)->Get<Tensor>().data<int>(); const int* label_data = ctx.Input<Tensor>(1)->data<int>();
auto* Y = context.Output(0)->GetMutable<Tensor>(); auto Y = ctx.Output<Tensor>(0);
Y->mutable_data<T>(context.GetPlace()); Y->mutable_data<T>(ctx.GetPlace());
T* Y_data = Y->data<T>(); T* Y_data = Y->data<T>();
int batch_size = X.dims()[0]; int batch_size = X->dims()[0];
int class_num = X.dims()[1]; int class_num = X->dims()[1];
// Y[i] = -log(X[i][j]) // Y[i] = -log(X[i][j])
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
......
...@@ -19,18 +19,17 @@ namespace operators { ...@@ -19,18 +19,17 @@ namespace operators {
class MulOp : public OperatorWithKernel { class MulOp : public OperatorWithKernel {
protected: protected:
void InferShape(const std::vector<const Tensor *> &inputs, void InferShape(const InferShapeContext &ctx) const override {
const std::vector<Tensor *> &outputs) const override { PADDLE_ENFORCE(ctx.InputSize() == 2, "The mul op must take two inputs");
PADDLE_ENFORCE(inputs.size() == 2, "The mul op must take two inputs"); auto dim0 = ctx.Input<Tensor>(0)->dims();
auto dim0 = inputs[0]->dims(); auto dim1 = ctx.Input<Tensor>(1)->dims();
auto dim1 = inputs[1]->dims();
PADDLE_ENFORCE(dim0.size() == 2 && dim1.size() == 2, PADDLE_ENFORCE(dim0.size() == 2 && dim1.size() == 2,
"The input of mul op must be matrix"); "The input of mul op must be matrix");
PADDLE_ENFORCE( PADDLE_ENFORCE(
dim0[1] == dim1[0], dim0[1] == dim1[0],
"First matrix's width must be equal with second matrix's height."); "First matrix's width must be equal with second matrix's height.");
PADDLE_ENFORCE(outputs.size() == 1, "The mul op must take one output"); PADDLE_ENFORCE(ctx.OutputSize() == 1, "The mul op must take one output");
outputs[0]->Resize({dim0[0], dim1[1]}); ctx.Output<Tensor>(0)->Resize({dim0[0], dim1[1]});
} }
}; };
...@@ -51,8 +50,7 @@ The equation is: Out = X * Y ...@@ -51,8 +50,7 @@ The equation is: Out = X * Y
class MulOpGrad : public OperatorWithKernel { class MulOpGrad : public OperatorWithKernel {
protected: protected:
void InferShape(const std::vector<const Tensor *> &inputs, void InferShape(const InferShapeContext &ctx) const override {}
const std::vector<Tensor *> &outputs) const override {}
std::string DebugString() const override { std::string DebugString() const override {
LOG(INFO) << "MulGrad"; LOG(INFO) << "MulGrad";
return ""; return "";
......
...@@ -22,19 +22,17 @@ namespace operators { ...@@ -22,19 +22,17 @@ namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class MulKernel : public OpKernel { class MulKernel : public OpKernel {
public: public:
void Compute(const KernelContext& context) const override { void Compute(const ExecutionContext& context) const override {
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair = { Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair = {
{Eigen::IndexPair<Eigen::DenseIndex>(1, 0)}}; {Eigen::IndexPair<Eigen::DenseIndex>(1, 0)}};
auto input0 = context.Input(0)->Get<Tensor>(); auto output = context.Output<Tensor>(0);
auto input1 = context.Input(1)->Get<Tensor>();
auto* output = context.Output(0)->GetMutable<Tensor>();
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
EigenMatrix<T>::From(*output).device(*(context.GetEigenDevice<Place>())) = EigenMatrix<T>::From(*output).device(*(context.GetEigenDevice<Place>())) =
EigenMatrix<T>::From(input0).contract(EigenMatrix<T>::From(input1), EigenMatrix<T>::From(*context.Input<Tensor>("X"))
dim_pair); .contract(EigenMatrix<T>::From(*context.Input<Tensor>("Y")),
dim_pair);
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -18,17 +18,17 @@ namespace operators { ...@@ -18,17 +18,17 @@ namespace operators {
class RowWiseAddOp : public OperatorWithKernel { class RowWiseAddOp : public OperatorWithKernel {
protected: protected:
void InferShape(const std::vector<const Tensor *> &inputs, void InferShape(const InferShapeContext &ctx) const override {
const std::vector<Tensor *> &outputs) const override { PADDLE_ENFORCE(ctx.InputSize() == 2UL,
PADDLE_ENFORCE(inputs.size() == 2UL, "Two inputs is needed by rowwise add"); "Two inputs is needed by rowwise add");
auto dim0 = inputs[0]->dims(); auto dim0 = ctx.Input<Tensor>(0)->dims();
auto dim1 = inputs[1]->dims(); auto dim1 = ctx.Input<Tensor>(1)->dims();
PADDLE_ENFORCE(dim0.size() == 2, "Input 0 must be matrix"); PADDLE_ENFORCE(dim0.size() == 2, "Input 0 must be matrix");
PADDLE_ENFORCE(dim1.size() == 1, "The second input must be vector"); PADDLE_ENFORCE(dim1.size() == 1, "The second input must be vector");
PADDLE_ENFORCE(dim0[1] == dim1[0], "The width of two input must be same"); PADDLE_ENFORCE(dim0[1] == dim1[0], "The width of two input must be same");
PADDLE_ENFORCE(outputs.size() == 1, "The output size must be 1"); PADDLE_ENFORCE(ctx.OutputSize() == 1, "The output size must be 1");
outputs[0]->Resize(inputs[0]->dims()); ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
} }
}; };
......
...@@ -21,14 +21,12 @@ namespace operators { ...@@ -21,14 +21,12 @@ namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class RowWiseAddKernel : public OpKernel { class RowWiseAddKernel : public OpKernel {
public: public:
void Compute(const KernelContext& context) const override { void Compute(const ExecutionContext& context) const override {
auto in0 = context.Input(0)->Get<Tensor>(); auto out = context.Output<Tensor>(0);
auto in1 = context.Input(1)->Get<Tensor>();
auto* out = context.Output(0)->GetMutable<Tensor>();
out->mutable_data<T>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());
auto input = EigenMatrix<T>::From(in0); auto input = EigenMatrix<T>::From(*context.Input<Tensor>(0));
auto bias = EigenVector<T>::From(in1); auto bias = EigenVector<T>::From(*context.Input<Tensor>(1));
auto output = EigenMatrix<T>::From(*out); auto output = EigenMatrix<T>::From(*out);
const int bias_size = bias.dimension(0); const int bias_size = bias.dimension(0);
......
...@@ -19,16 +19,15 @@ namespace operators { ...@@ -19,16 +19,15 @@ namespace operators {
class SGDOp : public OperatorWithKernel { class SGDOp : public OperatorWithKernel {
protected: protected:
void InferShape(const std::vector<const Tensor *> &inputs, void InferShape(const InferShapeContext &ctx) const override {
const std::vector<Tensor *> &outputs) const override { PADDLE_ENFORCE(ctx.InputSize() == 2, "Input size of SGDOp must be two");
PADDLE_ENFORCE(inputs.size() == 2, "Input size of SGDOp must be two"); PADDLE_ENFORCE(ctx.OutputSize() == 1, "Output size of SGDOp must be one");
PADDLE_ENFORCE(outputs.size() == 1, "Output size of SGDOp must be one"); PADDLE_ENFORCE(ctx.InputVar(0) != nullptr, "inputs[0] mast be set");
PADDLE_ENFORCE(inputs[0] != nullptr, "inputs[0] mast be set"); PADDLE_ENFORCE(ctx.InputVar(1) != nullptr, "inputs[1] mast be set");
PADDLE_ENFORCE(inputs[1] != nullptr, "inputs[1] mast be set"); PADDLE_ENFORCE(ctx.OutputVar(0) != nullptr, "outputs[0] mast be set");
PADDLE_ENFORCE(outputs[0] != nullptr, "outputs[0] mast be set"); PADDLE_ENFORCE(ctx.Input<Tensor>(0)->dims() == ctx.Input<Tensor>(1)->dims(),
PADDLE_ENFORCE(inputs[0]->dims() == inputs[1]->dims(),
"Two input of SGD Op's dimension must be same."); "Two input of SGD Op's dimension must be same.");
outputs[0]->Resize(inputs[0]->dims()); ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
} }
}; };
......
...@@ -21,16 +21,16 @@ namespace operators { ...@@ -21,16 +21,16 @@ namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class SGDOpKernel : public OpKernel { class SGDOpKernel : public OpKernel {
public: public:
void Compute(const KernelContext& ctx) const override { void Compute(const ExecutionContext& ctx) const override {
auto param = ctx.Input("param")->Get<Tensor>(); auto param = ctx.Input<Tensor>("param");
auto grad = ctx.Input("grad")->Get<Tensor>(); auto grad = ctx.Input<Tensor>("grad");
auto* param_out = ctx.Output(0)->GetMutable<Tensor>(); auto param_out = ctx.Output<Tensor>(0);
float lr = ctx.op_.GetAttr<float>("learning_rate"); float lr = ctx.op_.GetAttr<float>("learning_rate");
param_out->mutable_data<T>(ctx.GetPlace()); param_out->mutable_data<T>(ctx.GetPlace());
EigenVector<T>::Flatten(*param_out).device(*(ctx.GetEigenDevice<Place>())) = EigenVector<T>::Flatten(*param_out).device(*(ctx.GetEigenDevice<Place>())) =
EigenVector<T>::Flatten(param) - lr * EigenVector<T>::Flatten(grad); EigenVector<T>::Flatten(*param) - lr * EigenVector<T>::Flatten(*grad);
} }
}; };
......
...@@ -18,11 +18,10 @@ namespace operators { ...@@ -18,11 +18,10 @@ namespace operators {
class SigmoidOp : public OperatorWithKernel { class SigmoidOp : public OperatorWithKernel {
protected: protected:
void InferShape(const std::vector<const Tensor *> &inputs, void InferShape(const InferShapeContext &ctx) const override {
const std::vector<Tensor *> &outputs) const override { PADDLE_ENFORCE(ctx.InputSize() == 1, "Sigmoid Op only have one input");
PADDLE_ENFORCE(inputs.size() == 1, "Sigmoid Op only have one input"); PADDLE_ENFORCE(ctx.OutputSize() == 1, "Sigmoid Op only have one output");
PADDLE_ENFORCE(outputs.size() == 1, "Sigmoid Op only have one output"); ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
outputs[0]->Resize(inputs[0]->dims());
} }
}; };
...@@ -38,8 +37,7 @@ public: ...@@ -38,8 +37,7 @@ public:
class SigmoidOpGrad : public OperatorWithKernel { class SigmoidOpGrad : public OperatorWithKernel {
protected: protected:
void InferShape(const std::vector<const Tensor *> &inputs, void InferShape(const InferShapeContext &ctx) const override {}
const std::vector<Tensor *> &outputs) const override {}
std::string DebugString() const override { std::string DebugString() const override {
LOG(INFO) << "SigmoidGrad"; LOG(INFO) << "SigmoidGrad";
return ""; return "";
......
...@@ -22,15 +22,14 @@ namespace operators { ...@@ -22,15 +22,14 @@ namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class SigmoidKernel : public OpKernel { class SigmoidKernel : public OpKernel {
public: public:
void Compute(const KernelContext& context) const override { void Compute(const ExecutionContext& context) const override {
auto input = context.Input(0)->Get<Tensor>(); auto input = context.Input<Tensor>(0);
auto* output = context.Output(0)->GetMutable<Tensor>(); auto output = context.Output<Tensor>(0);
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
EigenVector<T>::Flatten(*output).device( EigenVector<T>::Flatten(*output).device(
*(context.GetEigenDevice<Place>())) = *(context.GetEigenDevice<Place>())) =
1.0 / (1.0 + (-1.0 * EigenVector<T>::Flatten(input)).exp()); 1.0 / (1.0 + (-1.0 * EigenVector<T>::Flatten(*input)).exp());
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -18,14 +18,13 @@ namespace operators { ...@@ -18,14 +18,13 @@ namespace operators {
class SoftmaxOp : public OperatorWithKernel { class SoftmaxOp : public OperatorWithKernel {
protected: protected:
void InferShape(const std::vector<const Tensor *> &inputs, void InferShape(const InferShapeContext &ctx) const override {
const std::vector<Tensor *> &outputs) const override { PADDLE_ENFORCE(ctx.InputSize() == 1, "Only one input is need for softmax");
PADDLE_ENFORCE(inputs.size() == 1, "Only one input is need for softmax"); PADDLE_ENFORCE(ctx.Input<Tensor>(0)->dims().size() == 2,
PADDLE_ENFORCE(inputs[0]->dims().size() == 2,
"The input of softmax op must be matrix"); "The input of softmax op must be matrix");
PADDLE_ENFORCE(outputs.size() == 1, "Only one output is need for softmax"); PADDLE_ENFORCE(ctx.OutputSize() == 1,
"Only one output is need for softmax");
outputs[0]->Resize(inputs[0]->dims()); ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
} }
}; };
...@@ -41,8 +40,7 @@ public: ...@@ -41,8 +40,7 @@ public:
class SoftmaxOpGrad : public OperatorWithKernel { class SoftmaxOpGrad : public OperatorWithKernel {
protected: protected:
void InferShape(const std::vector<const Tensor *> &inputs, void InferShape(const InferShapeContext &ctx) const override {}
const std::vector<Tensor *> &outputs) const override {}
std::string DebugString() const override { std::string DebugString() const override {
LOG(INFO) << "SoftmaxOpGrad"; LOG(INFO) << "SoftmaxOpGrad";
return ""; return "";
......
...@@ -22,12 +22,12 @@ namespace operators { ...@@ -22,12 +22,12 @@ namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class SoftmaxKernel : public OpKernel { class SoftmaxKernel : public OpKernel {
public: public:
void Compute(const KernelContext& context) const override { void Compute(const ExecutionContext& context) const override {
auto input = context.Input(0)->Get<Tensor>(); auto input = context.Input<Tensor>(0);
auto* output = context.Output(0)->GetMutable<Tensor>(); auto output = context.Output<Tensor>(0);
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
auto logits = EigenMatrix<T>::From(input); auto logits = EigenMatrix<T>::From(*input);
auto softmax = EigenMatrix<T>::From(*output); auto softmax = EigenMatrix<T>::From(*output);
const int kBatchDim = 0; const int kBatchDim = 0;
......
...@@ -22,7 +22,9 @@ namespace paddle { ...@@ -22,7 +22,9 @@ namespace paddle {
namespace operators { namespace operators {
using OpKernel = framework::OpKernel; using OpKernel = framework::OpKernel;
using KernelContext = framework::KernelContext; using InferShapeContext = framework::InferShapeContext;
using ExecutionContext = framework::ExecutionContext;
using Variable = framework::Variable;
template <typename T, template <typename T,
int MajorType = Eigen::RowMajor, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex> typename IndexType = Eigen::DenseIndex>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册