OpType和Operator
OpType采用枚举类enum class,和传统的枚举enum相比,可以降低命名空间的污染、避免发生隐式转换。
同时,定义了一个Operator作为父类,其派生类中只存放相关的参数以及修改参数的方法,不包含计算的实现。
explicit构造函数是用来防止隐式转换的。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
| enum class OpType { kOperatorUnknown = -1, kOperatorRelu = 0, };
class Operator { public: explicit Operator(OpType op_type); virtual ~Operator() = default;
public: OpType kOpType = OpType::kOperatorUnknown; };
|
ReluOperator
ReluOperator参数只有thresh,同时也定义了两个成员函数用来查看和修改thresh。
1 2 3 4 5 6 7 8 9 10 11 12 13
| class ReluOperator : public Operator { public: explicit ReluOperator(float thresh);
~ReluOperator() override = default;
void set_thresh(float thresh);
float get_thresh() const;
private: float thresh_ = 0.f; };
|
1 2 3 4 5 6 7 8 9 10 11
| ReluOperator::ReluOperator(float thresh) : thresh_(thresh), Operator(OpType::kOperatorRelu) {
}
void ReluOperator::set_thresh(float thresh) { this->thresh_ = thresh; }
float ReluOperator::get_thresh() const { return thresh_; }
|
Layer
定义了一个Layer作为父类,其派生类中负责具体计算的实现。
其中,Layer的Forwards方法是具体的执行函数,负责将输入的inputs中的数据,进行运算并存放到对应的outputs中。
1 2 3 4 5 6 7 8 9 10 11 12 13
| class Layer { public: explicit Layer(const std::string &layer_name);
virtual ~Layer() = default;
virtual void Forwards(const std::vector<std::shared_ptr<Tensor<float>>> &inputs, std::vector<std::shared_ptr<Tensor<float>>> &outputs);
private: std::string layer_name_; };
|
ReluLayer
ReluOperator负责存放计算图中的参数信息,不负责计算,而ReluLayer则负责具体的计算操作,
因而,实现了属性存储和运算过程的分离。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
| class ReluLayer : public Layer { public: explicit ReluLayer(const std::shared_ptr<Operator> &op);
~ReluLayer() override = default;
void Forwards(const std::vector<std::shared_ptr<Tensor<float>>> &inputs, std::vector<std::shared_ptr<Tensor<float>>> &outputs) override;
static std::shared_ptr<Layer> CreateInstance(const std::shared_ptr<Operator> &op);
private: std::shared_ptr<ReluOperator> op_; };
|
ReluLayer的构造函数,通过初始化列表:Layer("Relu"),调用父类Layer的构造函数来初始化属性layer_name_。
dynamic_cast是什么意思? 就是判断一下op指针是不是指向一个relu_op类的指针;
这边的op不是ReluOperator类型的指针,就报错;
我们这里只接受ReluOperator类型的指针;
父类指针必须指向子类ReluOperator类型的指针;
op.get()获得 shared_ptr 对象内部包含的普通指针;
为什么不讲构造函数设置为const std::shared_ptr &op?
为了接口统一,具体下节会说到。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37
| ReluLayer::ReluLayer(const std::shared_ptr<Operator> &op) : Layer("Relu") { CHECK(op->op_type_ == OpType::kOperatorRelu) << "Operator has a wrong type: " << int(op->op_type_); ReluOperator *relu_op = dynamic_cast<ReluOperator *>(op.get()); CHECK(relu_op != nullptr) << "Relu operator is empty";
this->op_ = std::make_unique<ReluOperator>(relu_op->get_thresh()); }
void ReluLayer::Forwards(const std::vector<std::shared_ptr<Tensor<float>>> &inputs, std::vector<std::shared_ptr<Tensor<float>>> &outputs) { CHECK(this->op_ != nullptr); CHECK(this->op_->op_type_ == OpType::kOperatorRelu);
const uint32_t batch_size = inputs.size(); for (int i = 0; i < batch_size; ++i) { CHECK(!inputs.at(i)->empty());
const std::shared_ptr<Tensor<float>> &input_data = inputs.at(i);
input_data->data().transform([&](float value) { float thresh = op_->get_thresh(); if (value >= thresh) { return value; } else { return 0.f; } });
outputs.push_back(input_data); } }
|
在Forwards中,首先读取输入input_data, 再对input_data使用armadillo自带的transform,
按照我们给定的thresh过滤其中的元素,如果value的值大于thresh则不变,如果小于thresh就返回0。
使用
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28
| TEST(test_layer, forward_relu1) { using namespace kuiper_infer;
std::shared_ptr<Tensor<float>> input = std::make_shared<Tensor<float>>(1, 1, 3); input->index(0) = -1.f; input->index(1) = -2.f; input->index(2) = 3.f; std::vector<std::shared_ptr<Tensor<float>>> inputs; inputs.push_back(input);
std::vector<std::shared_ptr<Tensor<float>>> outputs;
float thresh = 0.f; std::shared_ptr<Operator> relu_op = std::make_shared<ReluOperator>(thresh); ReluLayer layer(relu_op);
layer.Forwards(inputs, outputs); ASSERT_EQ(outputs.size(), 1);
for (int i = 0; i < outputs.size(); ++i) { ASSERT_EQ(outputs.at(i)->index(0), 0.f); ASSERT_EQ(outputs.at(i)->index(1), 0.f); ASSERT_EQ(outputs.at(i)->index(2), 3.f); } }
|