本文以arma::cube为基础实现了Tensor类,提供了更方便的访问方式和对外接口。

Tensor类模板

C++类模板例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
// Class template
template <typename T1, typename T2> // or: template<class T1, class T2>
class Complex {
public:
  // Bug fix: a constructor with a member-initializer list requires a body;
  // the original "Complex(T1 a, T2 b) : _a(a), _b(b);" does not compile.
  Complex(T1 a, T2 b) : _a(a), _b(b) {}
  Complex<T1, T2> operator+(Complex<T1, T2> &c);

private:
  T1 _a;
  T2 _b;
};

// Class template - full specialization (for <int, int>)
template <>
class Complex<int, int> {
public:
  // Bug fix: the member-initializer list needs a function body; the original
  // form ending in ';' does not compile.
  Complex(int a, int b) : _a(a), _b(b) {}
  Complex<int, int> operator+(Complex<int, int> &c);

private:
  int _a;
  int _b;
};

Complex<int, int> c1(1,2);

模板分为类模板与函数模板,特化分为全特化与偏特化(partial specialization)。

对于模板、模板的全特化和模板的偏特化, 以及同名普通函数都存在的情况下,编译器在编译阶段进行匹配时,只匹配普通函数和模板, 匹配顺序如下:

  1. 查找普通函数中有没有匹配的,如果有就选它
  2. 查找模板中有没有匹配的, 并选择最匹配的版本, 然后进行下面两步

注意, 上面规则没提到特化版本, 只有当编译器按规则2匹配到了某个模板, 才会继续进行该模板特化版本的匹配

  1. 查找全特化版本中有没有匹配的
  2. 查找偏特化版本中有没有匹配的

Tensor共有两个类型,一个类型是Tensor&lt;uint8_t&gt;,另一个类型是Tensor&lt;float&gt;。Tensor&lt;uint8_t&gt; 可能会在后续的量化课程中进行使用,目前还暂时未实现.
我们把Tensor&lt;uint8_t&gt;和Tensor&lt;float&gt;全特化,如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
// Primary template: generic tensor; only specializations are implemented.
template<typename T>
class Tensor {

};

// Full specialization for quantized (uint8_t) tensors.
template<>
class Tensor<uint8_t> {
// to be implemented
};

// Full specialization for float tensors.
template<>
class Tensor<float> {
// to be implemented
}; // bug fix: a class definition must be terminated with a semicolon

张量

const小结
常量引用

  • const修饰形参:表示函数体内不能修改该形参的值
  • const成员函数:表示不能改变所有成员变量的值
  • 常量引用:不能通过引用修改其所绑定的对象,但能以其它方式修改这个对象。

构造函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
  // Default constructor: leaves the tensor empty (no storage allocated).
explicit Tensor() = default;
// Constructor with explicit dimensions.
explicit Tensor(uint32_t channels, uint32_t rows, uint32_t cols);
// Copy constructor.
Tensor(const Tensor &tensor);
// Copy assignment operator.
Tensor<float> &operator=(const Tensor &tensor);

Tensor<float>::Tensor(uint32_t channels, uint32_t rows, uint32_t cols) {
  // arma::fcube is laid out as (n_rows, n_cols, n_slices): channels -> slices.
  data_ = arma::fcube(rows, cols, channels);
  // Bug fix: record the logical shape as well. The copy constructor and
  // operator= copy raw_shapes_ and Flatten() rewrites it, so leaving it empty
  // here makes a freshly constructed tensor inconsistent with copies of it.
  this->raw_shapes_ = std::vector<uint32_t>{channels, rows, cols};
}

Tensor<float>::Tensor(const Tensor &tensor)
    : raw_shapes_(tensor.raw_shapes_), data_(tensor.data_) {
  // Deep copy: arma::fcube's copy constructor duplicates the storage.
}

Tensor<float> &Tensor<float>::operator=(const Tensor &tensor) {
  // Guard against self-assignment, then copy the cube and the shape record.
  if (this == &tensor) {
    return *this;
  }
  data_ = tensor.data_;
  raw_shapes_ = tensor.raw_shapes_;
  return *this;
}

张量的维度大小

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
  // Number of rows.
uint32_t rows() const;
// Number of columns.
uint32_t cols() const;
// Number of channels.
uint32_t channels() const;
// Total number of elements in the tensor.
uint32_t size() const;
// Size of each dimension, in (channels, rows, cols) order.
std::vector<uint32_t> shapes() const;

uint32_t Tensor<float>::rows() const {
  // Row count of every channel; the tensor must be initialized.
  CHECK(!data_.empty());
  return data_.n_rows;
}

uint32_t Tensor<float>::cols() const {
  // Column count of every channel; the tensor must be initialized.
  CHECK(!data_.empty());
  return data_.n_cols;
}

uint32_t Tensor<float>::channels() const {
  // Channels are stored as arma cube slices.
  CHECK(!data_.empty());
  return data_.n_slices;
}

uint32_t Tensor<float>::size() const {
  // Total element count (= channels * rows * cols).
  CHECK(!data_.empty());
  return data_.size();
}

std::vector<uint32_t> Tensor<float>::shapes() const {
  // Shape reported in (channels, rows, cols) order.
  CHECK(!data_.empty());
  std::vector<uint32_t> dims{channels(), rows(), cols()};
  return dims;
}

取数据和索引

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
  // Access the underlying arma cube.
arma::fcube &data();
const arma::fcube &data() const;

// Access the data of a single channel.
arma::fmat &at(uint32_t channel);
const arma::fmat &at(uint32_t channel) const;

// Element access: returns data[channel, row, col].
float at(uint32_t channel, uint32_t row, uint32_t col) const;
float &at(uint32_t channel, uint32_t row, uint32_t col);

// Returns element number channel * Rows * Cols + row * Cols + col, i.e. an
// index into the flattened data.
// NOTE(review): arma stores each slice column-major, so the raw buffer order
// may not match this row-major description — verify before relying on it.
float index(uint32_t offset) const;

arma::fcube &Tensor<float>::data() {
  // Mutable access to the underlying cube.
  return data_;
}

const arma::fcube &Tensor<float>::data() const {
  // Read-only access to the underlying cube.
  return data_;
}

arma::fmat &Tensor<float>::at(uint32_t channel) {
  // Mutable view of one channel (an arma slice).
  CHECK_LT(channel, channels());
  return data_.slice(channel);
}

const arma::fmat &Tensor<float>::at(uint32_t channel) const {
  // Read-only view of one channel (an arma slice).
  CHECK_LT(channel, channels());
  return data_.slice(channel);
}

float Tensor<float>::at(uint32_t channel, uint32_t row, uint32_t col) const {
  // Read one element; note arma cubes index as (row, col, slice).
  CHECK_LT(row, rows());
  CHECK_LT(col, cols());
  CHECK_LT(channel, channels());
  return data_.at(row, col, channel);
}

float &Tensor<float>::at(uint32_t channel, uint32_t row, uint32_t col) {
  // Mutable reference to one element; arma cubes index as (row, col, slice).
  CHECK_LT(row, rows());
  CHECK_LT(col, cols());
  CHECK_LT(channel, channels());
  return data_.at(row, col, channel);
}

float Tensor<float>::index(uint32_t offset) const {
  // Returns the offset-th element of the underlying flat buffer.
  // NOTE(review): arma stores each slice column-major, so this is buffer
  // order, not necessarily row-major order — confirm against callers.
  // Consistency fix: use CHECK_LT like every other bound check in this class
  // (it also logs both operands on failure, unlike a bare CHECK).
  CHECK_LT(offset, this->data_.size());
  return this->data_.at(offset);
}

初始化张量

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
  // Replace the stored data (dimensions must match).
void set_data(const arma::fcube &data);
// Fill with ones.
void Ones();
// Fill with random values.
void Rand();

void Tensor<float>::set_data(const arma::fcube &data) {
// Overwrite the stored cube. All three extents must match the current ones;
// on mismatch, CHECK aborts and logs the differing extents.
CHECK(data.n_rows == this->data_.n_rows) << data.n_rows << " != " << this->data_.n_rows;
CHECK(data.n_cols == this->data_.n_cols) << data.n_cols << " != " << this->data_.n_cols;
CHECK(data.n_slices == this->data_.n_slices) << data.n_slices << " != " << this->data_.n_slices;
this->data_ = data;
}

void Tensor<float>::Rand() {
  // Fill every element with a draw from arma's randn (standard normal).
  CHECK(!data_.empty());
  data_.randn();
}

void Tensor<float>::Ones() {
  // Set every element to one (equivalent to fill(1.f)).
  CHECK(!data_.empty());
  data_.ones();
}

张量填充

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
  // Pad the borders of each channel.
void Padding(const std::vector<uint32_t> &pads, float padding_value);
// Fill with a scalar value.
void Fill(float value);
// Fill from a vector (row-major within each channel).
void Fill(const std::vector<float> &values);

void Tensor<float>::Padding(const std::vector<uint32_t> &pads, float padding_value) {
  // Pads the border of every channel with `padding_value`.
  // Usage: tensor.Padding({1, 1, 1, 1}, 0.f);
  CHECK(!this->data_.empty());
  CHECK_EQ(pads.size(), 4);
  const uint32_t pad_rows1 = pads.at(0); // up
  const uint32_t pad_rows2 = pads.at(1); // bottom
  const uint32_t pad_cols1 = pads.at(2); // left
  const uint32_t pad_cols2 = pads.at(3); // right

  // Bug fix: the previous implementation used insert_rows/insert_cols, which
  // always set the inserted rows/columns to zero, silently ignoring
  // `padding_value`. Build a cube pre-filled with padding_value instead and
  // copy the original data into its interior.
  arma::fcube padded(this->data_.n_rows + pad_rows1 + pad_rows2,
                     this->data_.n_cols + pad_cols1 + pad_cols2,
                     this->data_.n_slices);
  padded.fill(padding_value);
  padded.subcube(pad_rows1, pad_cols1, 0,
                 pad_rows1 + this->data_.n_rows - 1,
                 pad_cols1 + this->data_.n_cols - 1,
                 this->data_.n_slices - 1) = this->data_;
  this->data_ = padded;
}

void Tensor<float>::Fill(float value) {
  // Broadcast a scalar to every element.
  CHECK(!data_.empty());
  data_.fill(value);
}

void Tensor<float>::Fill(const std::vector<float> &values) {
// Fills the tensor from `values` in channel-major, row-major order:
// values[c * rows * cols + r * cols + col] becomes element (c, r, col).
CHECK(!this->data_.empty());
const uint32_t total_elems = this->data_.size();
CHECK_EQ(values.size(), total_elems);

const uint32_t rows = this->rows();
const uint32_t cols = this->cols();
const uint32_t planes = rows * cols;    // elements per channel
const uint32_t channels = this->data_.n_slices;

for(uint32_t i = 0; i < channels; i++) {
auto &channel_data = this->data_.slice(i);
// values.data() returns a pointer to the contiguous underlying array.
// arma::fmat is column-major, so build a (cols x rows) matrix from the
// row-major source and transpose it to obtain the intended (rows x cols)
// channel slice.
const arma::fmat &channel_data_t = arma::fmat(values.data() + i * planes, this->cols(), this->rows());
channel_data = channel_data_t.t();
}
}

其他

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
  // Whether the tensor is empty (not yet initialized).
bool empty() const;
// Print the tensor contents (via logging).
void Show();
// Flatten the tensor to one dimension.
void Flatten();

bool Tensor<float>::empty() const {
  // True until storage has been allocated.
  return data_.empty();
}

void Tensor<float>::Show() {
  // Log every channel's matrix through glog.
  const uint32_t channel_count = this->channels();
  for (uint32_t i = 0; i < channel_count; ++i) {
    LOG(INFO) << "Channel: " << i;
    LOG(INFO) << "\n" << this->data_.slice(i);
  }
}

void Tensor<float>::Flatten() {
// Reshapes the tensor to 1-D of length channels*rows*cols, copying elements
// in channel-major, row-major order (matching Fill(vector)).
CHECK(!this->data_.empty());
const uint32_t size = this->data_.size();
// Destination: a (size x 1 x 1) cube, i.e. a single column.
arma::fcube linear_cube(size, 1, 1);

uint32_t channel = this->channels();
uint32_t rows = this->rows();
uint32_t cols = this->cols();
uint32_t index = 0;

for (uint32_t c = 0; c < channel; ++c) {
const arma::fmat &matrix = this->data_.slice(c);

// Explicit row-major copy: arma stores slices column-major, so a plain
// arma::vectorise would yield a different (column-major) element order.
for (uint32_t r = 0; r < rows; ++r) {
for (uint32_t c_ = 0; c_ < cols; ++c_) {
linear_cube.at(index, 0, 0) = matrix.at(r, c_);
index += 1;
}
}
}
CHECK_EQ(index, size);
this->data_ = linear_cube;
// Record the new logical shape: a single dimension of length `size`.
this->raw_shapes_ = std::vector<uint32_t>{size};
}

完整的接口定义

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
// Full specialization of Tensor for float, backed by an arma::fcube.
template<>
class Tensor<float> {
public:
// Default constructor: creates an empty (uninitialized) tensor.
explicit Tensor() = default;

// Allocates storage for a channels x rows x cols tensor.
explicit Tensor(uint32_t channels, uint32_t rows, uint32_t cols);

// Copy constructor.
Tensor(const Tensor &tensor);

// Copy assignment operator.
Tensor<float> &operator=(const Tensor &tensor);

// Number of rows.
uint32_t rows() const;

// Number of columns.
uint32_t cols() const;

// Number of channels.
uint32_t channels() const;

// Total number of elements.
uint32_t size() const;

// Replace the stored data (dimensions must match).
void set_data(const arma::fcube &data);

// Whether the tensor is empty (not yet initialized).
bool empty() const;

// Element at the given flat offset in the underlying buffer.
float index(uint32_t offset) const;

// Shape in (channels, rows, cols) order.
std::vector<uint32_t> shapes() const;

// Access the underlying arma cube.
arma::fcube &data();

const arma::fcube &data() const;

// Access the data of a single channel.
arma::fmat &at(uint32_t channel);

const arma::fmat &at(uint32_t channel) const;

// Element access: data[channel, row, col].
float at(uint32_t channel, uint32_t row, uint32_t col) const;

float &at(uint32_t channel, uint32_t row, uint32_t col);

// Pad the borders of each channel with padding_value.
void Padding(const std::vector<uint32_t> &pads, float padding_value);

// Fill with a scalar value.
void Fill(float value);

// Fill from a vector (row-major within each channel).
void Fill(const std::vector<float> &values);

// Fill with ones.
void Ones();

// Fill with random values.
void Rand();

// Print the tensor contents (via logging).
void Show();

// Flatten the tensor to one dimension.
void Flatten();

private:
// Logical shape; rewritten by Flatten().
std::vector<uint32_t> raw_shapes_;
// Underlying storage: (rows, cols, channels) as (n_rows, n_cols, n_slices).
arma::fcube data_;
};

使用

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
TEST(test_tensor, create) {
  using namespace kuiper_infer;
  // Construct a 3-channel 32x32 tensor and verify its reported shape.
  Tensor<float> tensor(3, 32, 32);
  ASSERT_FALSE(tensor.empty());
  ASSERT_EQ(tensor.channels(), 3);
  ASSERT_EQ(tensor.rows(), 32);
  ASSERT_EQ(tensor.cols(), 32);
}

TEST(test_tensor, fill) {
using namespace kuiper_infer;
Tensor<float> tensor(3, 3, 3);
ASSERT_EQ(tensor.channels(), 3);
ASSERT_EQ(tensor.rows(), 3);
ASSERT_EQ(tensor.cols(), 3);

// 0..26 so every element receives a distinct value.
std::vector<float> values;
for (int i = 0; i < 27; ++i) {
values.push_back((float) i);
}
tensor.Fill(values);
LOG(INFO) << tensor.data();

// Fill stores values channel-major, row-major. Note the argument order:
// at(channel, row, col) is invoked as at(c, c_, r) with c_ ranging over
// cols() and r over rows(); this traversal only matches the fill order
// because the tensor is square (rows == cols == 3) — verify before reusing
// this pattern on non-square tensors.
int index = 0;
for (int c = 0; c < tensor.channels(); ++c) {
for (int c_ = 0; c_ < tensor.cols(); ++c_) {
for (int r = 0; r < tensor.rows(); ++r) {
ASSERT_EQ(values.at(index), tensor.at(c, c_, r));
index += 1;
}
}
}
LOG(INFO) << "Test1 passed!";
}

TEST(test_tensor, padding1) {
  using namespace kuiper_infer;
  Tensor<float> tensor(3, 3, 3);
  ASSERT_EQ(tensor.channels(), 3);
  ASSERT_EQ(tensor.rows(), 3);
  ASSERT_EQ(tensor.cols(), 3);

  tensor.Fill(1.f); // fill with ones
  tensor.Padding({1, 1, 1, 1}, 0); // pad the border with zeros
  ASSERT_EQ(tensor.rows(), 5);
  ASSERT_EQ(tensor.cols(), 5);

  // Bug fix: the original loop only checked the first row/column
  // (r == 0 || c_ == 0), never the bottom/right borders, and never verified
  // that the interior kept its value of 1.
  for (int c = 0; c < tensor.channels(); ++c) {
    for (int r = 0; r < tensor.rows(); ++r) {
      for (int c_ = 0; c_ < tensor.cols(); ++c_) {
        const bool on_border =
            r == 0 || r == tensor.rows() - 1 || c_ == 0 || c_ == tensor.cols() - 1;
        if (on_border) {
          ASSERT_EQ(tensor.at(c, r, c_), 0);
        } else {
          ASSERT_EQ(tensor.at(c, r, c_), 1);
        }
      }
    }
  }
  LOG(INFO) << "Test2 passed!";
}