本文以arma::cube为基础实现了Tensor类,提供了更方便的访问方式和对外接口。
// C++ class template example: a primary template Complex<T1, T2> and a
// full (explicit) specialization Complex<int, int>.
template <typename T1, typename T2>
class Complex {
 public:
  // A constructor with a member-initializer list still needs a body ({}).
  Complex(T1 a, T2 b) : _a(a), _b(b) {}

  // Component-wise addition. The operand is taken by const reference so that
  // const objects and temporaries can be added; the operator itself is const.
  Complex<T1, T2> operator+(const Complex<T1, T2> &c) const {
    return Complex<T1, T2>(_a + c._a, _b + c._b);
  }

  T1 a() const { return _a; }
  T2 b() const { return _b; }

 private:
  T1 _a;
  T2 _b;
};

// Full specialization: chosen instead of the primary template when both
// template arguments are int.
template <>
class Complex<int, int> {
 public:
  Complex(int a, int b) : _a(a), _b(b) {}

  Complex<int, int> operator+(const Complex<int, int> &c) const {
    return Complex<int, int>(_a + c._a, _b + c._b);
  }

  int a() const { return _a; }
  int b() const { return _b; }

 private:
  int _a;
  int _b;
};

Complex<int, int> c1(1, 2);
模板分为类模板与函数模板,特化分为全特化(full specialization)与偏特化(partial specialization)。注意:偏特化只适用于类模板,函数模板只能全特化。
对于模板、模板的全特化和模板的偏特化, 以及同名普通函数都存在的情况下,编译器在编译阶段进行匹配时,只匹配普通函数和模板, 匹配顺序如下:
查找普通函数中有没有匹配的,如果有就选它
查找模板中有没有匹配的,并选择最匹配的版本,然后进行下面两步
注意,上面的规则没有提到特化版本:只有当编译器按规则2选中了某个主模板之后,才会进一步进行特化版本的匹配
查找全特化版本中有没有匹配的
查找偏特化版本中有没有匹配的
Tensor共有两个类型:一个是Tensor<float>,另一个是Tensor<uint8_t>。Tensor<uint8_t> 可能会在后续的量化课程中使用,目前暂未实现。我们把Tensor<uint8_t>和Tensor<float>全特化,如下:
// Primary template; only the two full specializations below are intended to
// be used.
template <typename T>
class Tensor {};

// uint8_t tensor: reserved for later quantization work, not implemented yet.
template <>
class Tensor<uint8_t> {};

// float tensor: the specialization implemented in this article.
template <>
class Tensor<float> {};
张量 const小结 常量引用
const修饰形参:表示在函数内部不能修改该形参的值
const成员函数:表示在该函数内不能修改对象的(非mutable)成员变量
常量引用:不能通过引用修改其所绑定的对象,但能以其它方式修改这个对象。
构造函数 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 explicit Tensor () = default ; explicit Tensor (uint32_t channels, uint32_t rows, uint32_t cols) ; Tensor (const Tensor &tensor); Tensor<float > &operator =(const Tensor &tensor); Tensor<float >::Tensor (uint32_t channels, uint32_t rows, uint32_t cols) { data_ = arma::fcube (rows, cols, channels); } Tensor<float >::Tensor (const Tensor &tensor) { this ->data_ = tensor.data_; this ->raw_shapes_ = tensor.raw_shapes_; } Tensor<float > &Tensor<float >::operator =(const Tensor &tensor) { if (this != &tensor) { this ->data_ = tensor.data_; this ->raw_shapes_ = tensor.raw_shapes_; } return *this ; }
张量的维度大小 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 uint32_t rows () const ; uint32_t cols () const ; uint32_t channels () const ; uint32_t size () const ; std::vector<uint32_t > shapes () const ; uint32_t Tensor<float >::rows () const { CHECK (!this ->data_.empty ()); return this ->data_.n_rows; } uint32_t Tensor<float >::cols () const { CHECK (!this ->data_.empty ()); return this ->data_.n_cols; } uint32_t Tensor<float >::channels () const { CHECK (!this ->data_.empty ()); return this ->data_.n_slices; } uint32_t Tensor<float >::size () const { CHECK (!this ->data_.empty ()); return this ->data_.size (); } std::vector<uint32_t > Tensor<float >::shapes () const { CHECK (!this ->data_.empty ()); return {this ->channels (), this ->rows (), this ->cols ()}; }
取数据和索引 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 arma::fcube &data () ; const arma::fcube &data () const ; arma::fmat &at (uint32_t channel) ; const arma::fmat &at (uint32_t channel) const ; float at (uint32_t channel, uint32_t row, uint32_t col) const ; float &at (uint32_t channel, uint32_t row, uint32_t col) ; float index (uint32_t offset) const ; arma::fcube &Tensor<float >::data () { return this ->data_; } const arma::fcube &Tensor<float >::data () const { return this ->data_; } arma::fmat &Tensor<float >::at (uint32_t channel) { CHECK_LT (channel, this ->channels ()); return this ->data_.slice (channel); } const arma::fmat &Tensor<float >::at (uint32_t channel) const { CHECK_LT (channel, this ->channels ()); return this ->data_.slice (channel); } float Tensor<float >::at (uint32_t channel, uint32_t row, uint32_t col) const { CHECK_LT (row, this ->rows ()); CHECK_LT (col, this ->cols ()); CHECK_LT (channel, this ->channels ()); return this ->data_.at (row, col, channel); } float &Tensor<float >::at (uint32_t channel, uint32_t row, uint32_t col) { CHECK_LT (row, this ->rows ()); CHECK_LT (col, this ->cols ()); CHECK_LT (channel, this ->channels ()); return this ->data_.at (row, col, channel); } float Tensor<float >::index (uint32_t offset) const { CHECK (offset < this ->data_.size ()); return this ->data_.at (offset); }
初始化张量 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 void set_data (const arma::fcube &data) ; void Ones () ; void Rand () ; void Tensor<float >::set_data (const arma::fcube &data) { CHECK (data.n_rows == this ->data_.n_rows) << data.n_rows << " != " << this ->data_.n_rows; CHECK (data.n_cols == this ->data_.n_cols) << data.n_cols << " != " << this ->data_.n_cols; CHECK (data.n_slices == this ->data_.n_slices) << data.n_slices << " != " << this ->data_.n_slices; this ->data_ = data; } void Tensor<float >::Rand () { CHECK (!this ->data_.empty ()); this ->data_.randn (); } void Tensor<float >::Ones () { CHECK (!this ->data_.empty ()); this ->data_.fill (1. ); }
// Padding and filling (excerpt from class Tensor<float>).
void Padding(const std::vector<uint32_t> &pads, float padding_value);
void Fill(float value);
void Fill(const std::vector<float> &values);

// Pads every channel in place.
// pads = {top, bottom, left, right}: rows added above/below and columns added
// to the left/right of each channel. New elements are set to padding_value.
void Tensor<float>::Padding(const std::vector<uint32_t> &pads, float padding_value) {
  CHECK(!this->data_.empty());
  CHECK_EQ(pads.size(), 4);
  const uint32_t pad_rows1 = pads.at(0);
  const uint32_t pad_rows2 = pads.at(1);
  const uint32_t pad_cols1 = pads.at(2);
  const uint32_t pad_cols2 = pads.at(3);

  const uint32_t rows = this->data_.n_rows;
  const uint32_t cols = this->data_.n_cols;
  const uint32_t channels = this->data_.n_slices;
  const uint32_t padded_rows = rows + pad_rows1 + pad_rows2;
  const uint32_t padded_cols = cols + pad_cols1 + pad_cols2;

  // insert_rows/insert_cols zero-fill the inserted elements, which silently
  // ignored padding_value. Build the padded cube explicitly instead: fill it
  // with the padding value, then copy the original data into the interior.
  arma::fcube padded(padded_rows, padded_cols, channels);
  padded.fill(padding_value);
  padded.subcube(pad_rows1, pad_cols1, 0,
                 pad_rows1 + rows - 1, pad_cols1 + cols - 1, channels - 1) = this->data_;
  this->data_ = padded;
  // The shape changed, so keep raw_shapes_ in sync (as Flatten() does).
  this->raw_shapes_ = std::vector<uint32_t>{channels, padded_rows, padded_cols};
}

// Sets every element to `value`.
void Tensor<float>::Fill(float value) {
  CHECK(!this->data_.empty());
  this->data_.fill(value);
}

// Fills the tensor from `values`, which must hold exactly size() elements laid
// out channel by channel, each channel in row-major order.
void Tensor<float>::Fill(const std::vector<float> &values) {
  CHECK(!this->data_.empty());
  const uint32_t total_elems = this->data_.size();
  CHECK_EQ(values.size(), total_elems);

  const uint32_t rows = this->rows();
  const uint32_t cols = this->cols();
  const uint32_t planes = rows * cols;
  const uint32_t channels = this->data_.n_slices;

  for (uint32_t i = 0; i < channels; i++) {
    // Armadillo matrices are column-major: reading a row-major plane as a
    // cols x rows matrix and transposing it yields the rows x cols channel.
    const arma::fmat channel_data_t(values.data() + i * planes, cols, rows);
    this->data_.slice(i) = channel_data_t.t();
  }
}
其他 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 bool empty () const ; void Show () ; void Flatten () ; bool Tensor<float >::empty () const { return this ->data_.empty (); } void Tensor<float >::Show () { for (uint32_t i = 0 ; i < this ->channels (); ++i) { LOG (INFO) << "Channel: " << i; LOG (INFO) << "\n" << this ->data_.slice (i); } } void Tensor<float >::Flatten () { CHECK (!this ->data_.empty ()); const uint32_t size = this ->data_.size (); arma::fcube linear_cube (size, 1 , 1 ) ; uint32_t channel = this ->channels (); uint32_t rows = this ->rows (); uint32_t cols = this ->cols (); uint32_t index = 0 ; for (uint32_t c = 0 ; c < channel; ++c) { const arma::fmat &matrix = this ->data_.slice (c); for (uint32_t r = 0 ; r < rows; ++r) { for (uint32_t c_ = 0 ; c_ < cols; ++c_) { linear_cube.at (index, 0 , 0 ) = matrix.at (r, c_); index += 1 ; } } } CHECK_EQ (index, size); this ->data_ = linear_cube; this ->raw_shapes_ = std::vector<uint32_t >{size}; }
完整的接口定义 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 template <>class Tensor <float > { public : explicit Tensor () = default ; explicit Tensor (uint32_t channels, uint32_t rows, uint32_t cols) ; Tensor (const Tensor &tensor); Tensor<float > &operator =(const Tensor &tensor); uint32_t rows () const ; uint32_t cols () const ; uint32_t channels () const ; uint32_t size () const ; void set_data (const arma::fcube &data) ; bool empty () const ; float index (uint32_t offset) const ; std::vector<uint32_t > shapes () const ; arma::fcube &data () ; const arma::fcube &data () const ; arma::fmat &at (uint32_t channel) ; const arma::fmat &at (uint32_t channel) const ; float at (uint32_t channel, uint32_t row, uint32_t col) const ; float &at (uint32_t channel, uint32_t row, uint32_t col) ; void Padding (const std::vector<uint32_t > &pads, float padding_value) ; void Fill (float value) ; void Fill (const std::vector<float > &values) ; void Ones () ; void Rand () ; void Show () ; void Flatten () ; private : std::vector<uint32_t > raw_shapes_; arma::fcube data_; };
// Usage: GoogleTest cases for Tensor<float>.
TEST(test_tensor, create) {
  using namespace kuiper_infer;
  Tensor<float> tensor(3, 32, 32);
  ASSERT_EQ(tensor.channels(), 3u);
  ASSERT_EQ(tensor.rows(), 32u);
  ASSERT_EQ(tensor.cols(), 32u);
  ASSERT_EQ(tensor.empty(), false);
}

TEST(test_tensor, fill) {
  using namespace kuiper_infer;
  // Non-square channels, so a row/column mix-up in Fill() or in the loops
  // below would be detected instead of passing by symmetry.
  Tensor<float> tensor(2, 3, 4);
  ASSERT_EQ(tensor.channels(), 2u);
  ASSERT_EQ(tensor.rows(), 3u);
  ASSERT_EQ(tensor.cols(), 4u);

  std::vector<float> values;
  values.reserve(tensor.size());
  for (uint32_t i = 0; i < tensor.size(); ++i) {
    values.push_back(static_cast<float>(i));
  }
  tensor.Fill(values);
  LOG(INFO) << tensor.data();

  // Fill() consumes values channel by channel, row-major within a channel.
  uint32_t index = 0;
  for (uint32_t c = 0; c < tensor.channels(); ++c) {
    for (uint32_t r = 0; r < tensor.rows(); ++r) {
      for (uint32_t col = 0; col < tensor.cols(); ++col) {
        ASSERT_EQ(values.at(index), tensor.at(c, r, col));
        index += 1;
      }
    }
  }
  ASSERT_EQ(index, values.size());
  LOG(INFO) << "Test1 passed!";
}

TEST(test_tensor, padding1) {
  using namespace kuiper_infer;
  Tensor<float> tensor(3, 3, 3);
  ASSERT_EQ(tensor.channels(), 3u);
  ASSERT_EQ(tensor.rows(), 3u);
  ASSERT_EQ(tensor.cols(), 3u);

  tensor.Fill(1.f);
  tensor.Padding({1, 1, 1, 1}, 0);
  ASSERT_EQ(tensor.rows(), 5u);
  ASSERT_EQ(tensor.cols(), 5u);

  // Every border element (all four sides) must be the padding value and
  // every interior element must keep the original value.
  for (uint32_t c = 0; c < tensor.channels(); ++c) {
    for (uint32_t r = 0; r < tensor.rows(); ++r) {
      for (uint32_t col = 0; col < tensor.cols(); ++col) {
        const bool on_border = r == 0 || col == 0 || r + 1 == tensor.rows() ||
                               col + 1 == tensor.cols();
        ASSERT_EQ(tensor.at(c, r, col), on_border ? 0.f : 1.f);
      }
    }
  }
  LOG(INFO) << "Test2 passed!";
}