提交 61ebacbc 编写于 作者: Q Qiao Longfei 提交者: GitHub

use operator context and infer context (#3024)

* use operator context

* optimize code

* update net infershape

* update InferShape

* disable override InferShape(scope) in OperatorBase

* change InferShapeImpl to InferShape

* add template to OperatorContext Input/Output

* merge Input InputVar, Output OutputVar

* change Inputs to MultiInput

* fix conflict

* fix MultiInput bugs and add unit test

* rename KernelContext to ExecutionContext

* clean code

* change InferShape to protected

* fix template bug

* refine code

* use InputVar instead of Input<Variable>

* typo

* optimize code
上级 0b680772
......@@ -16,8 +16,7 @@ static int run_cnt = 0;
class TestOp : public OperatorBase {
public:
void InferShape(
const std::shared_ptr<framework::Scope>& scope) const override {
void InferShape(const std::shared_ptr<Scope>& scope) const override {
++infer_shape_cnt;
}
void Run(const std::shared_ptr<framework::Scope>& scope,
......
......@@ -20,7 +20,7 @@ namespace paddle {
namespace framework {
template <>
Eigen::DefaultDevice* KernelContext::GetEigenDevice<
Eigen::DefaultDevice* ExecutionContext::GetEigenDevice<
platform::CPUPlace, Eigen::DefaultDevice>() const {
return device_context_.get_eigen_device<Eigen::DefaultDevice>();
}
......@@ -28,7 +28,7 @@ Eigen::DefaultDevice* KernelContext::GetEigenDevice<
#ifndef PADDLE_ONLY_CPU
template <>
Eigen::GpuDevice*
KernelContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
return device_context_.get_eigen_device<Eigen::GpuDevice>();
}
#endif
......
......@@ -31,22 +31,9 @@ limitations under the License. */
namespace paddle {
namespace framework {
template <typename T>
struct EigenDeviceConverter;
template <>
struct EigenDeviceConverter<platform::CPUPlace> {
using EigenDeviceType = Eigen::DefaultDevice;
};
#ifndef PADDLE_ONLY_CPU
template <>
struct EigenDeviceConverter<platform::GPUPlace> {
using EigenDeviceType = Eigen::GpuDevice;
};
#endif
class OperatorBase;
class InferShapeContext;
class ExecutionContext;
/**
* OperatorBase has the basic element that Net will call to do computation.
* Only CreateOperator from OpRegistry will new Operator directly. User
......@@ -112,46 +99,127 @@ class OperatorBase {
std::shared_ptr<std::unordered_map<std::string, int>> in_out_idxs_;
};
class KernelContext {
class OperatorContext {
public:
KernelContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& device_context)
: op_(*op), scope_(scope), device_context_(device_context) {}
OperatorContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope)
: op_(*op), scope_(scope) {}
size_t InputSize() const { return op_.inputs_.size(); }
const Variable* Input(int index) const {
return scope_->GetVariable(op_.inputs_[index]);
size_t OutputSize() const { return op_.outputs_.size(); }
const Variable* InputVar(const size_t& index) const {
return scope_->GetVariable(op_.inputs_.at(index));
}
Variable* Output(int index) const {
return scope_->GetVariable(op_.outputs_[index]);
Variable* OutputVar(const size_t& index) const {
return scope_->GetVariable(op_.outputs_.at(index));
}
const Variable* Input(const std::string& name) const {
const Variable* InputVar(const std::string& name) const {
return scope_->GetVariable(op_.Input(name));
}
const Variable* Output(const std::string& name) const {
Variable* OutputVar(const std::string& name) const {
return scope_->GetVariable(op_.Output(name));
}
const std::vector<const Variable*> Inputs(const std::string& name) const {
const std::vector<const Variable*> MultiInputVar(
const std::string& name) const {
auto names = op_.Inputs(name);
std::vector<const Variable*> res;
res.reserve(names.size());
std::transform(
names.begin(), names.end(), res.begin(),
names.begin(), names.end(), std::back_inserter(res),
[this](const std::string& name) { return scope_->GetVariable(name); });
return res;
}
const std::vector<const Variable*> Outputs(const std::string& name) const {
std::vector<const Variable*> MultiOutputVar(const std::string& name) const {
auto names = op_.Outputs(name);
std::vector<const Variable*> res;
res.reserve(names.size());
std::transform(
names.begin(), names.end(), res.begin(),
names.begin(), names.end(), std::back_inserter(res),
[this](const std::string& name) { return scope_->GetVariable(name); });
return res;
}
template <typename T>
const T* Input(const size_t& index) const {
return &(InputVar(index)->Get<T>());
}
template <typename T>
T* Output(const size_t& index) const {
return OutputVar(index)->GetMutable<T>();
}
template <typename T>
const T* Input(const std::string& name) const {
return &(InputVar(name)->Get<T>());
}
template <typename T>
T* Output(const std::string& name) const {
return OutputVar(name)->GetMutable<T>();
}
template <typename T>
const std::vector<const T*> MultiInput(const std::string& name) const {
auto names = op_.Inputs(name);
std::vector<const T*> res;
res.reserve(names.size());
std::transform(names.begin(), names.end(), std::back_inserter(res),
[this](const std::string& name) {
return &scope_->GetVariable(name)->Get<T>();
});
return res;
}
template <typename T>
std::vector<const T*> MultiOutput(const std::string& name) const {
auto names = op_.Outputs(name);
std::vector<const T*> res;
res.reserve(names.size());
std::transform(names.begin(), names.end(), std::back_inserter(res),
[this](const std::string& name) {
return scope_->GetVariable(name)->GetMutable<T>();
});
return res;
}
const OperatorBase& op_;
const std::shared_ptr<Scope>& scope_;
};
class InferShapeContext : public OperatorContext {
public:
InferShapeContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope)
: OperatorContext(op, scope) {}
};
template <typename T>
struct EigenDeviceConverter;
template <>
struct EigenDeviceConverter<platform::CPUPlace> {
using EigenDeviceType = Eigen::DefaultDevice;
};
#ifndef PADDLE_ONLY_CPU
template <>
struct EigenDeviceConverter<platform::GPUPlace> {
using EigenDeviceType = Eigen::GpuDevice;
};
#endif
class ExecutionContext : public OperatorContext {
public:
ExecutionContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& device_context)
: OperatorContext(op, scope), device_context_(device_context) {}
template <typename PlaceType,
typename DeviceType =
typename EigenDeviceConverter<PlaceType>::EigenDeviceType>
......@@ -159,38 +227,23 @@ class KernelContext {
platform::Place GetPlace() const { return device_context_.GetPlace(); }
const OperatorBase& op_;
const std::shared_ptr<Scope>& scope_;
const platform::DeviceContext& device_context_;
};
class OpKernel {
public:
/**
* KernelContext is the only parameter of Kernel Run function.
* ExecutionContext is the only parameter of Kernel Run function.
* Run will get input/output variables, state such as momentum and
* device resource such as CUDA stream, cublas handle, etc. from
* KernelContext. User should construct it before run the Operator.
* ExecutionContext. User should construct it before run the Operator.
*/
virtual void Compute(const KernelContext& context) const = 0;
virtual void Compute(const ExecutionContext& context) const = 0;
virtual ~OpKernel() {}
};
template <typename T>
struct VarToTensor {};
template <>
struct VarToTensor<Tensor*> {
Tensor* operator()(Variable* var) { return var->GetMutable<Tensor>(); }
};
template <>
struct VarToTensor<const Tensor*> {
const Tensor* operator()(Variable* var) { return &var->Get<Tensor>(); }
};
class OperatorWithKernel : public OperatorBase {
public:
struct OpKernelKey {
......@@ -216,10 +269,14 @@ class OperatorWithKernel : public OperatorBase {
using OpKernelMap =
std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;
void InferShape(const std::shared_ptr<Scope>& scope) const {
InferShape(InferShapeContext(this, scope));
}
void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const final {
auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx));
opKernel->Compute(KernelContext(this, scope, dev_ctx));
opKernel->Compute(ExecutionContext(this, scope, dev_ctx));
}
static std::unordered_map<std::string /* op_type */, OpKernelMap>&
......@@ -228,34 +285,8 @@ class OperatorWithKernel : public OperatorBase {
return g_all_op_kernels;
}
void InferShape(const std::shared_ptr<Scope>& scope) const final {
std::vector<const Tensor*> ins;
VarNamesToTensors(scope, inputs_, &ins);
std::vector<Tensor*> outs;
VarNamesToTensors(scope, outputs_, &outs);
InferShape(ins, outs);
};
private:
template <typename T>
void VarNamesToTensors(const std::shared_ptr<Scope>& scope,
const std::vector<std::string>& var_names,
std::vector<T>* container) const {
container->reserve(var_names.size());
VarToTensor<T> convert;
for (auto& name : var_names) {
auto var = scope->GetVariable(name);
if (var != nullptr) {
container->push_back(convert(var));
} else {
container->push_back(nullptr);
}
}
}
protected:
virtual void InferShape(const std::vector<const Tensor*>& inputs,
const std::vector<Tensor*>& outputs) const = 0;
virtual void InferShape(const InferShapeContext& ctx) const = 0;
};
} // namespace framework
......
......@@ -24,7 +24,8 @@ static int op_run_num = 0;
class OpWithoutKernelTest : public OperatorBase {
public:
void Init() override { x = 1; }
void InferShape(const std::shared_ptr<Scope>& scope) const override {}
void InferShape(
const std::shared_ptr<framework::Scope>& scope) const override {}
void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override {
op_run_num++;
......@@ -73,6 +74,7 @@ TEST(OperatorBase, all) {
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
scope->CreateVariable("OUT1");
ASSERT_EQ(paddle::framework::op_run_num, 0);
op->InferShape(scope);
op->Run(scope, device_context);
ASSERT_EQ(paddle::framework::op_run_num, 1);
}
......@@ -97,14 +99,13 @@ static int cpu_kernel_run_num = 0;
class OpWithKernelTest : public OperatorWithKernel {
protected:
void InferShape(const std::vector<const Tensor*>& inputs,
const std::vector<Tensor*>& outputs) const override {}
void InferShape(const framework::InferShapeContext& ctx) const override {}
};
template <typename T1, typename T2>
class CPUKernelTest : public OpKernel {
public:
void Compute(const KernelContext& ctx) const {
void Compute(const ExecutionContext& ctx) const {
std::cout << "this is cpu kernel" << std::endl;
std::cout << ctx.op_.DebugString() << std::endl;
cpu_kernel_run_num++;
......@@ -117,7 +118,8 @@ class CPUKernelTest : public OpKernel {
class OperatorMultiInputsTest : public OperatorBase {
public:
void Init() override { x = 1; }
void InferShape(const std::shared_ptr<Scope>& scope) const override {}
void InferShape(
const std::shared_ptr<framework::Scope>& scope) const override {}
void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override {
ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr);
......@@ -149,13 +151,31 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker
class CPUKernalMultiInputsTest : public OpKernel {
public:
void Compute(const KernelContext& ctx) const {
void Compute(const ExecutionContext& ctx) const {
auto xs = ctx.op_.Inputs("xs");
ASSERT_EQ(xs.size(), 3UL);
ASSERT_EQ(xs[0], "x0");
ASSERT_EQ(xs[1], "x1");
ASSERT_EQ(xs[2], "x2");
auto inVar0 = ctx.MultiInputVar("xs");
ASSERT_EQ(inVar0.size(), 3);
auto intVar1 = ctx.InputVar("k");
ASSERT_NE(intVar1, nullptr);
auto outVar0 = ctx.MultiOutputVar("ys");
ASSERT_EQ(outVar0.size(), 2);
auto inTensor0 = ctx.MultiInput<Tensor>("xs");
ASSERT_EQ(inTensor0.size(), 3);
auto intTensor1 = ctx.Input<Tensor>("k");
ASSERT_NE(intTensor1, nullptr);
auto outTensor0 = ctx.MultiOutput<Tensor>("ys");
ASSERT_EQ(outTensor0.size(), 2);
auto k = ctx.op_.Input("k");
ASSERT_EQ(k, "k0");
......@@ -233,6 +253,12 @@ TEST(OpKernel, multi_inputs) {
paddle::platform::CPUDeviceContext cpu_device_context;
auto scope = std::make_shared<Scope>();
scope->CreateVariable("x0")->GetMutable<Tensor>();
scope->CreateVariable("x1")->GetMutable<Tensor>();
scope->CreateVariable("x2")->GetMutable<Tensor>();
scope->CreateVariable("k0")->GetMutable<Tensor>();
scope->CreateVariable("y0")->GetMutable<Tensor>();
scope->CreateVariable("y1")->GetMutable<Tensor>();
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
op->Run(scope, cpu_device_context);
......
......@@ -19,16 +19,16 @@ namespace operators {
class AddOp : public OperatorWithKernel {
protected:
void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<Tensor *> &outputs) const override {
PADDLE_ENFORCE(inputs.size() == 2, "Input size of AddOp must be two");
PADDLE_ENFORCE(outputs.size() == 1, "Output size of AddOp must be one");
PADDLE_ENFORCE(
inputs[0] != nullptr && inputs[1] != nullptr && outputs[0] != nullptr,
"Inputs/Outputs of AddOp must all be set");
PADDLE_ENFORCE(inputs[0]->dims() == inputs[1]->dims(),
void InferShape(const InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 2, "Input size of AddOp must be two");
PADDLE_ENFORCE(ctx.OutputSize() == 1, "Output size of AddOp must be one");
PADDLE_ENFORCE(ctx.InputVar(0) != nullptr && ctx.InputVar(1) != nullptr,
"Inputs of AddOp must all be set");
PADDLE_ENFORCE(ctx.OutputVar(0) != nullptr,
"Outputs of AddOp must all be set");
PADDLE_ENFORCE(ctx.Input<Tensor>(0)->dims() == ctx.Input<Tensor>(1)->dims(),
"Two input of Add Op's dimension must be same.");
outputs[0]->Resize(inputs[0]->dims());
ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
}
};
......@@ -49,8 +49,7 @@ The equation is: Out = X + Y
class AddOpGrad : public OperatorWithKernel {
protected:
void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<Tensor *> &outputs) const override {}
void InferShape(const InferShapeContext &ctx) const override {}
std::string DebugString() const override {
LOG(INFO) << "AddOpGrad";
return "";
......
......@@ -21,16 +21,17 @@ namespace operators {
template <typename Place, typename T>
class AddKernel : public OpKernel {
public:
void Compute(const KernelContext& context) const override {
auto input0 = context.Input(0)->Get<Tensor>();
auto input1 = context.Input(1)->Get<Tensor>();
auto output = context.Output(0)->GetMutable<Tensor>();
void Compute(const ExecutionContext& context) const override {
auto input0 = context.Input<Tensor>(0);
auto input1 = context.Input<Tensor>(1);
auto output = context.Output<Tensor>(0);
output->mutable_data<T>(context.GetPlace());
EigenVector<T>::Flatten(*output).device(
*(context.GetEigenDevice<Place>())) =
EigenVector<T>::Flatten(input0) + EigenVector<T>::Flatten(input1);
framework::EigenVector<T>::Flatten(*input0) +
framework::EigenVector<T>::Flatten(*input1);
}
};
......
......@@ -19,20 +19,20 @@ namespace operators {
class OnehotCrossEntropyOp : public OperatorWithKernel {
protected:
void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<Tensor *> &outputs) const override {
PADDLE_ENFORCE(inputs.size() == 2,
void InferShape(const InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 2,
"Input size of OnehotCrossEntropyOp must be two");
PADDLE_ENFORCE(outputs.size() == 1,
PADDLE_ENFORCE(ctx.OutputSize() == 1,
"Output size of OnehotCrossEntropyOp must be one");
PADDLE_ENFORCE(inputs[0] != nullptr && inputs[1] != nullptr,
PADDLE_ENFORCE(ctx.InputVar(0) != nullptr && ctx.InputVar(1) != nullptr,
"Inputs of OnehotCrossEntropyOp must all be set");
PADDLE_ENFORCE(outputs[0] != nullptr,
PADDLE_ENFORCE(ctx.OutputVar(0) != nullptr,
"Outputs of OnehotCrossEntropyOp must all be set");
PADDLE_ENFORCE(inputs[0]->dims().size() == 2, "X's dimension must be 2.");
PADDLE_ENFORCE(outputs[0]->dims().size() == 1,
PADDLE_ENFORCE(ctx.Input<Tensor>(0)->dims().size() == 2,
"X's dimension must be 2.");
PADDLE_ENFORCE(ctx.Output<Tensor>(0)->dims().size() == 1,
"label's dimension must be 1.");
outputs[0]->Resize({inputs[0]->dims()[0]});
ctx.Output<Tensor>(0)->Resize({ctx.Input<Tensor>(0)->dims()[0]});
}
};
......
......@@ -23,18 +23,18 @@ class OnehotCrossEntropyOpKernel : public OpKernel {
public:
constexpr T LOG_THRESHOLD() const { return static_cast<T>(1e-20); }
void Compute(const KernelContext& context) const override {
auto X = context.Input(0)->Get<Tensor>();
const T* X_data = X.data<T>();
const int* label_data = context.Input(1)->Get<Tensor>().data<int>();
auto* Y = context.Output(0)->GetMutable<Tensor>();
void Compute(const ExecutionContext& ctx) const override {
auto X = ctx.Input<Tensor>(0);
const T* X_data = X->data<T>();
const int* label_data = ctx.Input<Tensor>(1)->data<int>();
auto Y = ctx.Output<Tensor>(0);
Y->mutable_data<T>(context.GetPlace());
Y->mutable_data<T>(ctx.GetPlace());
T* Y_data = Y->data<T>();
int batch_size = X.dims()[0];
int class_num = X.dims()[1];
int batch_size = X->dims()[0];
int class_num = X->dims()[1];
// Y[i] = -log(X[i][j])
for (int i = 0; i < batch_size; ++i) {
......
......@@ -19,18 +19,17 @@ namespace operators {
class MulOp : public OperatorWithKernel {
protected:
void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<Tensor *> &outputs) const override {
PADDLE_ENFORCE(inputs.size() == 2, "The mul op must take two inputs");
auto dim0 = inputs[0]->dims();
auto dim1 = inputs[1]->dims();
void InferShape(const InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 2, "The mul op must take two inputs");
auto dim0 = ctx.Input<Tensor>(0)->dims();
auto dim1 = ctx.Input<Tensor>(1)->dims();
PADDLE_ENFORCE(dim0.size() == 2 && dim1.size() == 2,
"The input of mul op must be matrix");
PADDLE_ENFORCE(
dim0[1] == dim1[0],
"First matrix's width must be equal with second matrix's height.");
PADDLE_ENFORCE(outputs.size() == 1, "The mul op must take one output");
outputs[0]->Resize({dim0[0], dim1[1]});
PADDLE_ENFORCE(ctx.OutputSize() == 1, "The mul op must take one output");
ctx.Output<Tensor>(0)->Resize({dim0[0], dim1[1]});
}
};
......@@ -51,8 +50,7 @@ The equation is: Out = X * Y
class MulOpGrad : public OperatorWithKernel {
protected:
void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<Tensor *> &outputs) const override {}
void InferShape(const InferShapeContext &ctx) const override {}
std::string DebugString() const override {
LOG(INFO) << "MulGrad";
return "";
......
......@@ -22,19 +22,17 @@ namespace operators {
template <typename Place, typename T>
class MulKernel : public OpKernel {
public:
void Compute(const KernelContext& context) const override {
void Compute(const ExecutionContext& context) const override {
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair = {
{Eigen::IndexPair<Eigen::DenseIndex>(1, 0)}};
auto input0 = context.Input(0)->Get<Tensor>();
auto input1 = context.Input(1)->Get<Tensor>();
auto* output = context.Output(0)->GetMutable<Tensor>();
auto output = context.Output<Tensor>(0);
output->mutable_data<T>(context.GetPlace());
EigenMatrix<T>::From(*output).device(*(context.GetEigenDevice<Place>())) =
EigenMatrix<T>::From(input0).contract(EigenMatrix<T>::From(input1),
dim_pair);
EigenMatrix<T>::From(*context.Input<Tensor>("X"))
.contract(EigenMatrix<T>::From(*context.Input<Tensor>("Y")),
dim_pair);
}
};
} // namespace operators
......
......@@ -18,17 +18,17 @@ namespace operators {
class RowWiseAddOp : public OperatorWithKernel {
protected:
void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<Tensor *> &outputs) const override {
PADDLE_ENFORCE(inputs.size() == 2UL, "Two inputs is needed by rowwise add");
auto dim0 = inputs[0]->dims();
auto dim1 = inputs[1]->dims();
void InferShape(const InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 2UL,
"Two inputs is needed by rowwise add");
auto dim0 = ctx.Input<Tensor>(0)->dims();
auto dim1 = ctx.Input<Tensor>(1)->dims();
PADDLE_ENFORCE(dim0.size() == 2, "Input 0 must be matrix");
PADDLE_ENFORCE(dim1.size() == 1, "The second input must be vector");
PADDLE_ENFORCE(dim0[1] == dim1[0], "The width of two input must be same");
PADDLE_ENFORCE(outputs.size() == 1, "The output size must be 1");
outputs[0]->Resize(inputs[0]->dims());
PADDLE_ENFORCE(ctx.OutputSize() == 1, "The output size must be 1");
ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
}
};
......
......@@ -21,14 +21,12 @@ namespace operators {
template <typename Place, typename T>
class RowWiseAddKernel : public OpKernel {
public:
void Compute(const KernelContext& context) const override {
auto in0 = context.Input(0)->Get<Tensor>();
auto in1 = context.Input(1)->Get<Tensor>();
auto* out = context.Output(0)->GetMutable<Tensor>();
void Compute(const ExecutionContext& context) const override {
auto out = context.Output<Tensor>(0);
out->mutable_data<T>(context.GetPlace());
auto input = EigenMatrix<T>::From(in0);
auto bias = EigenVector<T>::From(in1);
auto input = EigenMatrix<T>::From(*context.Input<Tensor>(0));
auto bias = EigenVector<T>::From(*context.Input<Tensor>(1));
auto output = EigenMatrix<T>::From(*out);
const int bias_size = bias.dimension(0);
......
......@@ -19,16 +19,15 @@ namespace operators {
class SGDOp : public OperatorWithKernel {
protected:
void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<Tensor *> &outputs) const override {
PADDLE_ENFORCE(inputs.size() == 2, "Input size of SGDOp must be two");
PADDLE_ENFORCE(outputs.size() == 1, "Output size of SGDOp must be one");
PADDLE_ENFORCE(inputs[0] != nullptr, "inputs[0] mast be set");
PADDLE_ENFORCE(inputs[1] != nullptr, "inputs[1] mast be set");
PADDLE_ENFORCE(outputs[0] != nullptr, "outputs[0] mast be set");
PADDLE_ENFORCE(inputs[0]->dims() == inputs[1]->dims(),
void InferShape(const InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 2, "Input size of SGDOp must be two");
PADDLE_ENFORCE(ctx.OutputSize() == 1, "Output size of SGDOp must be one");
PADDLE_ENFORCE(ctx.InputVar(0) != nullptr, "inputs[0] mast be set");
PADDLE_ENFORCE(ctx.InputVar(1) != nullptr, "inputs[1] mast be set");
PADDLE_ENFORCE(ctx.OutputVar(0) != nullptr, "outputs[0] mast be set");
PADDLE_ENFORCE(ctx.Input<Tensor>(0)->dims() == ctx.Input<Tensor>(1)->dims(),
"Two input of SGD Op's dimension must be same.");
outputs[0]->Resize(inputs[0]->dims());
ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
}
};
......
......@@ -21,16 +21,16 @@ namespace operators {
template <typename Place, typename T>
class SGDOpKernel : public OpKernel {
public:
void Compute(const KernelContext& ctx) const override {
auto param = ctx.Input("param")->Get<Tensor>();
auto grad = ctx.Input("grad")->Get<Tensor>();
auto* param_out = ctx.Output(0)->GetMutable<Tensor>();
void Compute(const ExecutionContext& ctx) const override {
auto param = ctx.Input<Tensor>("param");
auto grad = ctx.Input<Tensor>("grad");
auto param_out = ctx.Output<Tensor>(0);
float lr = ctx.op_.GetAttr<float>("learning_rate");
param_out->mutable_data<T>(ctx.GetPlace());
EigenVector<T>::Flatten(*param_out).device(*(ctx.GetEigenDevice<Place>())) =
EigenVector<T>::Flatten(param) - lr * EigenVector<T>::Flatten(grad);
EigenVector<T>::Flatten(*param) - lr * EigenVector<T>::Flatten(*grad);
}
};
......
......@@ -18,11 +18,10 @@ namespace operators {
class SigmoidOp : public OperatorWithKernel {
protected:
void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<Tensor *> &outputs) const override {
PADDLE_ENFORCE(inputs.size() == 1, "Sigmoid Op only have one input");
PADDLE_ENFORCE(outputs.size() == 1, "Sigmoid Op only have one output");
outputs[0]->Resize(inputs[0]->dims());
void InferShape(const InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 1, "Sigmoid Op only have one input");
PADDLE_ENFORCE(ctx.OutputSize() == 1, "Sigmoid Op only have one output");
ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
}
};
......@@ -38,8 +37,7 @@ public:
class SigmoidOpGrad : public OperatorWithKernel {
protected:
void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<Tensor *> &outputs) const override {}
void InferShape(const InferShapeContext &ctx) const override {}
std::string DebugString() const override {
LOG(INFO) << "SigmoidGrad";
return "";
......
......@@ -22,15 +22,14 @@ namespace operators {
template <typename Place, typename T>
class SigmoidKernel : public OpKernel {
public:
void Compute(const KernelContext& context) const override {
auto input = context.Input(0)->Get<Tensor>();
auto* output = context.Output(0)->GetMutable<Tensor>();
void Compute(const ExecutionContext& context) const override {
auto input = context.Input<Tensor>(0);
auto output = context.Output<Tensor>(0);
output->mutable_data<T>(context.GetPlace());
EigenVector<T>::Flatten(*output).device(
*(context.GetEigenDevice<Place>())) =
1.0 / (1.0 + (-1.0 * EigenVector<T>::Flatten(input)).exp());
1.0 / (1.0 + (-1.0 * EigenVector<T>::Flatten(*input)).exp());
}
};
} // namespace operators
......
......@@ -18,14 +18,13 @@ namespace operators {
class SoftmaxOp : public OperatorWithKernel {
protected:
void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<Tensor *> &outputs) const override {
PADDLE_ENFORCE(inputs.size() == 1, "Only one input is need for softmax");
PADDLE_ENFORCE(inputs[0]->dims().size() == 2,
void InferShape(const InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 1, "Only one input is need for softmax");
PADDLE_ENFORCE(ctx.Input<Tensor>(0)->dims().size() == 2,
"The input of softmax op must be matrix");
PADDLE_ENFORCE(outputs.size() == 1, "Only one output is need for softmax");
outputs[0]->Resize(inputs[0]->dims());
PADDLE_ENFORCE(ctx.OutputSize() == 1,
"Only one output is need for softmax");
ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
}
};
......@@ -41,8 +40,7 @@ public:
class SoftmaxOpGrad : public OperatorWithKernel {
protected:
void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<Tensor *> &outputs) const override {}
void InferShape(const InferShapeContext &ctx) const override {}
std::string DebugString() const override {
LOG(INFO) << "SoftmaxOpGrad";
return "";
......
......@@ -22,12 +22,12 @@ namespace operators {
template <typename Place, typename T>
class SoftmaxKernel : public OpKernel {
public:
void Compute(const KernelContext& context) const override {
auto input = context.Input(0)->Get<Tensor>();
auto* output = context.Output(0)->GetMutable<Tensor>();
void Compute(const ExecutionContext& context) const override {
auto input = context.Input<Tensor>(0);
auto output = context.Output<Tensor>(0);
output->mutable_data<T>(context.GetPlace());
auto logits = EigenMatrix<T>::From(input);
auto logits = EigenMatrix<T>::From(*input);
auto softmax = EigenMatrix<T>::From(*output);
const int kBatchDim = 0;
......
......@@ -22,7 +22,9 @@ namespace paddle {
namespace operators {
using OpKernel = framework::OpKernel;
using KernelContext = framework::KernelContext;
using InferShapeContext = framework::InferShapeContext;
using ExecutionContext = framework::ExecutionContext;
using Variable = framework::Variable;
template <typename T,
int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册