FPGA多通道卷积加速器:从零构建手写识别的硬件引擎
我最近在做一个很有意思的项目:在FPGA上部署CNN,实现手写图片的识别。本篇文章是我迈出的第二步,具体代码已发布在GitHub上。
模块介绍
卷积神经网络(CNN)可以分为卷积层、池化层、激活层、全连接层等结构。本篇实现的,就是CNN卷积层中的卷积运算模块。
卷积运算的过程如下图所示:
在权重参数已经确定的情况下,我们可以将这过程看成数据滑窗和卷积运算的这两个步骤的重复运算。在前文中,我们已经实现了window模块,而此处我们实现卷积运算模块。
运算过程如下:
$$
\begin{bmatrix}1&2&3\\4&5&6\\7&8&9\end{bmatrix} \ast \begin{bmatrix}1&0&1\\0&1&0\\1&1&2\end{bmatrix} = 1\cdot1 + 2\cdot0 + 3\cdot1 + 4\cdot0 + 5\cdot1 + 6\cdot0 + 7\cdot1 + 8\cdot1 + 9\cdot2 = 42
$$
代码
- 模块可配置参数、输入和输出定义
为了支持多通道并行处理,输入为所有输入通道展平后的数据,如一维的窗口数据和权重参数
DATA_WIDTH和WEIGHT_WIDTH分开定义,因为后续工作中会对权重定点数量化
// mult_acc_comb: purely combinational multiply-accumulate for one
// multi-channel convolution window.
//
// Every window element is multiplied by its weight, the products are summed
// within each channel and then across channels, and the total is saturated to
// OUTPUT_WIDTH bits. All arithmetic is unsigned. conv_out is forced to zero
// whenever conv_valid is low.
module mult_acc_comb #(
    parameter DATA_WIDTH   = 8,   // bit width of one window element
    parameter KERNEL_SIZE  = 3,   // kernel is KERNEL_SIZE x KERNEL_SIZE
    parameter IN_CHANNEL   = 3,   // number of input channels processed in parallel
    parameter WEIGHT_WIDTH = 8,   // bit width of one weight (kept separate from DATA_WIDTH)
    parameter OUTPUT_WIDTH = 20,  // configurable output width
    // Accumulator width. One product needs DATA_WIDTH + WEIGHT_WIDTH bits and
    // summing N products needs $clog2(N) more; 4 extra guard bits are kept.
    // Using WEIGHT_WIDTH (instead of 2*DATA_WIDTH) keeps the accumulator wide
    // enough when the two widths differ; the value is unchanged when they match.
    parameter ACC_WIDTH    = DATA_WIDTH + WEIGHT_WIDTH + 4
                             + $clog2(KERNEL_SIZE*KERNEL_SIZE*IN_CHANNEL)
)(
    // Input data interface: all channels flattened into one bus each
    input                                                        window_valid,
    input  [IN_CHANNEL*KERNEL_SIZE*KERNEL_SIZE*DATA_WIDTH-1:0]   multi_channel_window_in,
    input                                                        weight_valid,
    input  [IN_CHANNEL*KERNEL_SIZE*KERNEL_SIZE*WEIGHT_WIDTH-1:0] multi_channel_weight_in,
    // Output data interface
    output [OUTPUT_WIDTH-1:0]                                    conv_out,   // saturated sum, OUTPUT_WIDTH bits
    output                                                       conv_valid  // high when both input valids are high
);
- 定义内部相关信号
// Internal signal declarations for the multiply-accumulate datapath.

// Number of weights in one filter (all channels, all kernel positions).
localparam WEIGHTS_PER_FILTER = IN_CHANNEL * KERNEL_SIZE * KERNEL_SIZE;

// Unpacked multi-channel window data and weight data (unsigned),
// indexed as [channel][kernel position].
wire [DATA_WIDTH-1:0] channel_window_data [0:IN_CHANNEL-1][0:KERNEL_SIZE*KERNEL_SIZE-1];
wire [WEIGHT_WIDTH-1:0] channel_weight_data [0:IN_CHANNEL-1][0:KERNEL_SIZE*KERNEL_SIZE-1];

// Product for every channel and every kernel position (unsigned);
// DATA_WIDTH+WEIGHT_WIDTH bits hold the full-precision product.
wire [DATA_WIDTH+WEIGHT_WIDTH-1:0] mult_results [0:IN_CHANNEL-1][0:KERNEL_SIZE*KERNEL_SIZE-1];

// Accumulated sum of the products within each channel.
wire [ACC_WIDTH-1:0] channel_sums [0:IN_CHANNEL-1];

// Final sum accumulated across all channels.
wire [ACC_WIDTH-1:0] total_sum;

// Generate-loop index variables.
genvar ch, i_idx, k_idx, c_idx;
- 输入数据解包
// Unpack the flattened input buses into per-channel, per-position elements.
generate
    for (ch = 0; ch < IN_CHANNEL; ch = ch + 1) begin : unpack_gen
        for (i_idx = 0; i_idx < KERNEL_SIZE*KERNEL_SIZE; i_idx = i_idx + 1) begin : element_gen
            // Window data: element (ch, i_idx) sits at flat index
            // ch*K*K + i_idx, counted from the least-significant end of the bus.
            assign channel_window_data[ch][i_idx] =
                multi_channel_window_in[(ch*KERNEL_SIZE*KERNEL_SIZE + i_idx)*DATA_WIDTH +: DATA_WIDTH];

            // Weight data: the flat index is mirrored (WEIGHTS_PER_FILTER-1-idx),
            // so element (0, 0) is taken from the most-significant slot of the
            // weight bus, i.e. the weights are packed in the opposite order to
            // the window data.
            assign channel_weight_data[ch][i_idx] =
                multi_channel_weight_in[(WEIGHTS_PER_FILTER - 1 - (ch*KERNEL_SIZE*KERNEL_SIZE + i_idx))*WEIGHT_WIDTH +: WEIGHT_WIDTH];
        end
    end
endgenerate
a[b +: c] 的含义是:从 a 的第 b 位开始,向高位方向连续提取 c 位,也就是 a[b+c-1 : b]。例如 a[8 +: 8] 等价于 a[15:8]。
输入的window和weight的数据结构变化如下
- 并行卷积运算
所有通道同时进行卷积运算
// Parallel multiplication: every kernel position of every channel is
// multiplied at the same time (one multiplier per element).
generate
    for (ch = 0; ch < IN_CHANNEL; ch = ch + 1) begin : mult_ch_gen
        for (i_idx = 0; i_idx < KERNEL_SIZE*KERNEL_SIZE; i_idx = i_idx + 1) begin : mult_elem_gen
            assign mult_results[ch][i_idx] =
                channel_window_data[ch][i_idx] * channel_weight_data[ch][i_idx];
        end
    end
endgenerate

// Per-channel accumulation in combinational logic. The additions are written
// as a linear chain; the assignment target is ACC_WIDTH wide, so the operands
// are extended to ACC_WIDTH before they are added.
generate
    for (ch = 0; ch < IN_CHANNEL; ch = ch + 1) begin : sum_ch_gen
        if (KERNEL_SIZE == 3) begin : kernel3_sum
            // 3x3 kernel: the nine products are summed in one expression.
            assign channel_sums[ch] =
                mult_results[ch][0] + mult_results[ch][1] + mult_results[ch][2] +
                mult_results[ch][3] + mult_results[ch][4] + mult_results[ch][5] +
                mult_results[ch][6] + mult_results[ch][7] + mult_results[ch][8];
        end else begin : general_sum
            // Any other kernel size: running sum, partial_sums[k] holds the
            // sum of products 0..k.
            wire [ACC_WIDTH-1:0] partial_sums [0:KERNEL_SIZE*KERNEL_SIZE-1];
            assign partial_sums[0] = mult_results[ch][0];
            for (k_idx = 1; k_idx < KERNEL_SIZE*KERNEL_SIZE; k_idx = k_idx + 1) begin : acc_gen
                assign partial_sums[k_idx] = partial_sums[k_idx-1] + mult_results[ch][k_idx];
            end
            assign channel_sums[ch] = partial_sums[KERNEL_SIZE*KERNEL_SIZE-1];
        end
    end
endgenerate
- 跨通道累加并输出
对所有通道结果进行相加,进行饱和处理,然后输出
// Cross-channel accumulation (combinational).
generate
    if (IN_CHANNEL == 3) begin : channel3_sum
        // Three channels: sum them in one expression.
        assign total_sum = channel_sums[0] + channel_sums[1] + channel_sums[2];
    end else begin : general_channel_sum
        // Any other channel count: running sum, channel_partial_sums[c] holds
        // the sum of channels 0..c.
        wire [ACC_WIDTH-1:0] channel_partial_sums [0:IN_CHANNEL-1];
        assign channel_partial_sums[0] = channel_sums[0];
        for (c_idx = 1; c_idx < IN_CHANNEL; c_idx = c_idx + 1) begin : ch_acc_gen
            assign channel_partial_sums[c_idx] = channel_partial_sums[c_idx-1] + channel_sums[c_idx];
        end
        assign total_sum = channel_partial_sums[IN_CHANNEL-1];
    end
endgenerate

// Unsigned saturation (combinational).
// Returns `value` clipped to the largest OUTPUT_WIDTH-bit unsigned number.
// Overflow is detected by checking whether any bit at or above position
// OUTPUT_WIDTH is set. Unlike comparing against (1 << OUTPUT_WIDTH) - 1, this
// stays correct when OUTPUT_WIDTH >= 32 (the 32-bit literal shift would
// overflow) and when OUTPUT_WIDTH > ACC_WIDTH (no out-of-range part-select;
// the value is simply zero-extended).
function [OUTPUT_WIDTH-1:0] saturate;
    input [ACC_WIDTH-1:0] value; // unsigned accumulator value
    begin
        if (|(value >> OUTPUT_WIDTH))
            saturate = {OUTPUT_WIDTH{1'b1}}; // clip to the maximum
        else
            saturate = value;                // fits: keep the low OUTPUT_WIDTH bits
    end
endfunction

// Output logic (combinational): the result is valid only while both the
// window and the weights are valid; otherwise the output is forced to zero.
assign conv_valid = window_valid && weight_valid;
assign conv_out   = conv_valid ? saturate(total_sum) : {OUTPUT_WIDTH{1'b0}};

endmodule // mult_acc_comb (the original listing never closed the module)
测试
mult_acc_comb_tb.v
为验证其功能,使用多个测试用例(case)进行测试,并将实际输出与期望结果进行对比。
`timescale 1ns / 1ps

// Self-checking testbench for the unsigned combinational mult_acc_comb module.
// Each test drives the flattened window/weight buses, waits 1 ns for the
// combinational logic to settle, then compares conv_valid / conv_out against
// the expected values through check_and_report.
// The stimulus literals ({27{8'd...}}) assume the default configuration:
// 3 channels x 3x3 kernel = 27 elements of 8 bits each.
module mult_acc_comb_tb;

    parameter DATA_WIDTH   = 8;
    parameter KERNEL_SIZE  = 3;
    parameter IN_CHANNEL   = 3;
    parameter WEIGHT_WIDTH = 8;
    parameter OUTPUT_WIDTH = 20;
    parameter ACC_WIDTH    = 2*DATA_WIDTH + 4 + $clog2(KERNEL_SIZE*KERNEL_SIZE*IN_CHANNEL);

    // DUT inputs
    reg                                                        window_valid;
    reg  [IN_CHANNEL*KERNEL_SIZE*KERNEL_SIZE*DATA_WIDTH-1:0]   multi_channel_window_in;
    reg                                                        weight_valid;
    reg  [IN_CHANNEL*KERNEL_SIZE*KERNEL_SIZE*WEIGHT_WIDTH-1:0] multi_channel_weight_in;

    // DUT outputs
    wire [OUTPUT_WIDTH-1:0] conv_out;
    wire                    conv_valid;

    // Largest value representable on conv_out (the saturation ceiling).
    localparam MAX_UNSIGNED_OUT_VAL = (1 << OUTPUT_WIDTH) - 1;

    // Test 2 raw sum for the unsigned context: 27 elements, each 2*3.
    // 162 is far below the 20-bit ceiling, so no clipping actually occurs in
    // Test 2 even though its description mentions saturation.
    localparam EXPECTED_SUM_TEST2_UNSIGNED_RAW = 3 * 9 * 2 * 3; // 162
    localparam EXPECTED_CONV_OUT_TEST2_UNSIGNED_SAT =
        (EXPECTED_SUM_TEST2_UNSIGNED_RAW > MAX_UNSIGNED_OUT_VAL) ? MAX_UNSIGNED_OUT_VAL
                                                                 : EXPECTED_SUM_TEST2_UNSIGNED_RAW;

    // Largest window element / weight values, used only in Test 8's message.
    localparam MAX_ELEMENT_VAL_TB        = (1 << DATA_WIDTH) - 1;
    localparam MAX_WEIGHT_ELEMENT_VAL_TB = (1 << WEIGHT_WIDTH) - 1;

    // Device under test
    mult_acc_comb #(
        .DATA_WIDTH(DATA_WIDTH),
        .KERNEL_SIZE(KERNEL_SIZE),
        .IN_CHANNEL(IN_CHANNEL),
        .WEIGHT_WIDTH(WEIGHT_WIDTH),
        .OUTPUT_WIDTH(OUTPUT_WIDTH),
        .ACC_WIDTH(ACC_WIDTH)
    ) dut (
        .window_valid(window_valid),
        .multi_channel_window_in(multi_channel_window_in),
        .weight_valid(weight_valid),
        .multi_channel_weight_in(multi_channel_weight_in),
        .conv_out(conv_out),
        .conv_valid(conv_valid)
    );

    // Bookkeeping for the final summary
    reg     all_tests_passed_flag;
    integer test_id_counter;
    integer num_errors;

    // Compare the DUT outputs with the expected values and print both.
    // When the expected valid is 0, conv_out must be all zeros regardless of
    // expected_out_val. The caller prints the test description beforehand.
    task check_and_report;
        input [OUTPUT_WIDTH-1:0] expected_out_val;
        input                    expected_valid_val;
        begin
            test_id_counter = test_id_counter + 1;

            // Always display Expected and Actual
            $display(" Expected: conv_valid=%b, conv_out=%d", expected_valid_val, expected_out_val);
            $display(" Actual: conv_valid=%b, conv_out=%d", conv_valid, conv_out);

            if (conv_valid === expected_valid_val &&
                ( (expected_valid_val === 1'b0) ? (conv_out === {OUTPUT_WIDTH{1'b0}})
                                                : (conv_out === expected_out_val) ) ) begin
                $display(" Test ID %0d: Status: PASSED", test_id_counter);
            end else begin
                $display(" Test ID %0d: Status: FAILED", test_id_counter);
                all_tests_passed_flag = 1'b0;
                num_errors = num_errors + 1;
            end
            $display("--------------------------------------------------");
        end
    endtask

    initial begin
        $display("=== Comprehensive UNSIGNED Combinational MultAcc Test (OUTPUT_WIDTH=%0d) ===", OUTPUT_WIDTH);
        all_tests_passed_flag = 1'b1;
        test_id_counter = 0;
        num_errors = 0;

        // Initialize
        window_valid = 0;
        weight_valid = 0;
        multi_channel_window_in = 0;
        multi_channel_weight_in = 0;
        #10;

        // Test 1
        $display("Test Description: Simple Positive Values (1*1, sum 27)");
        multi_channel_window_in = {27{8'd1}};
        multi_channel_weight_in = {27{8'd1}};
        window_valid = 1;
        weight_valid = 1;
        #1;
        check_and_report(27, 1'b1);
        #10;

        // Test 2
        $display("Test Description: Positive Values with Saturation (2*3, raw %0d, sat %0d)", EXPECTED_SUM_TEST2_UNSIGNED_RAW, EXPECTED_CONV_OUT_TEST2_UNSIGNED_SAT);
        multi_channel_window_in = {27{8'd2}};
        multi_channel_weight_in = {27{8'd3}};
        #1;
        check_and_report(EXPECTED_CONV_OUT_TEST2_UNSIGNED_SAT, 1'b1);
        #10;

        // Test 3
        $display("Test Description: Invalid Inputs (both valid_n low)");
        window_valid = 0;
        weight_valid = 0;
        #1;
        check_and_report(0, 1'b0);
        #10;

        // Test 4
        $display("Test Description: Zero Window Data, Non-zero Weights");
        window_valid = 1;
        weight_valid = 1;
        multi_channel_window_in = {27{8'd0}};
        multi_channel_weight_in = {27{8'd5}};
        #1;
        check_and_report(0, 1'b1);
        #10;

        // Test 5
        $display("Test Description: Non-zero Window, Zero Weight Data");
        multi_channel_window_in = {27{8'd5}};
        multi_channel_weight_in = {27{8'd0}};
        #1;
        check_and_report(0, 1'b1);
        #10;

        // Test 6
        $display("Test Description: All Zero Inputs");
        multi_channel_window_in = {27{8'd0}};
        multi_channel_weight_in = {27{8'd0}};
        #1;
        check_and_report(0, 1'b1);
        #10;

        // Test 7
        $display("Test Description: Large values (no saturation with 20-bit output)");
        multi_channel_window_in = {27{8'd5}};
        multi_channel_weight_in = {27{8'd5}};
        #1;
        check_and_report(27*5*5, 1'b1); // 27*25 = 675, well within 20-bit range
        #10;

        // Test 8
        $display("Test Description: Max Val Inputs (Win=%d, Wgt=%d), should saturate to %d", MAX_ELEMENT_VAL_TB, MAX_WEIGHT_ELEMENT_VAL_TB, MAX_UNSIGNED_OUT_VAL);
        multi_channel_window_in = {27{{DATA_WIDTH{1'b1}}}};
        multi_channel_weight_in = {27{{WEIGHT_WIDTH{1'b1}}}};
        #1;
        // 27 * 255 * 255 = 1,755,675, which exceeds the 20-bit max (1,048,575),
        // so the output should saturate.
        check_and_report(MAX_UNSIGNED_OUT_VAL, 1'b1);
        #10;

        // Test 8.5: Test 20-bit range capability
        $display("Test Description: Medium values to test 20-bit range (100*100, sum 270000)");
        multi_channel_window_in = {27{8'd100}};
        multi_channel_weight_in = {27{8'd100}};
        #1;
        check_and_report(27*100*100, 1'b1); // 27*10000 = 270000, well within 20-bit range
        #10;

        // Test 9: Window valid toggles
        $display("--- Test Sequence 9: Window Valid Toggles (base inputs 1*1, sum 27) ---");
        multi_channel_window_in = {27{8'd1}};
        multi_channel_weight_in = {27{8'd1}};
        weight_valid = 1;
        $display(" Sub-Test Description: WinValid=1 (Start)");
        window_valid = 1;
        #1;
        check_and_report(27, 1'b1);
        $display(" Sub-Test Description: WinValid=0");
        window_valid = 0;
        #1;
        check_and_report(0, 1'b0);
        $display(" Sub-Test Description: WinValid=1 (End)");
        window_valid = 1;
        #1;
        check_and_report(27, 1'b1);
        #10;

        // Test 10: Weight valid toggles
        $display("--- Test Sequence 10: Weight Valid Toggles (base inputs 1*1, sum 27) ---");
        window_valid = 1; // inputs are still 1s
        $display(" Sub-Test Description: WeightValid=1 (Start)");
        weight_valid = 1;
        #1;
        check_and_report(27, 1'b1);
        $display(" Sub-Test Description: WeightValid=0");
        weight_valid = 0;
        #1;
        check_and_report(0, 1'b0);
        $display(" Sub-Test Description: WeightValid=1 (End)");
        weight_valid = 1;
        #1;
        check_and_report(27, 1'b1);
        #10;

        // Final Summary
        $display("==================================================");
        if (all_tests_passed_flag) begin
            $display("FINAL STATUS: SUCCESS! All %0d UNSIGNED Combinational MultAcc tests passed!", test_id_counter);
        end else begin
            $display("FINAL STATUS: FAILED. %0d out of %0d UNSIGNED Combinational MultAcc tests did not pass.", num_errors, test_id_counter);
        end
        $display("==================================================");
        $finish;
    end

endmodule
结果
window模块每个周期都会传递数据,因而采用组合逻辑实现卷积运算。当输入数据同时有效,也就是window_valid和weight_valid同时为高时,mult_acc_comb进行运算,conv_valid拉高,如下图所示
输出打印结果:
=== Comprehensive UNSIGNED Combinational MultAcc Test (OUTPUT_WIDTH=20) ===
Test Description: Simple Positive Values (1*1, sum 27)
Expected: conv_valid=1, conv_out= 27
Actual: conv_valid=1, conv_out= 27
Test ID 1: Status: PASSED
Test Description: Positive Values with Saturation (2*3, raw 162, sat 162)
Expected: conv_valid=1, conv_out= 162
Actual: conv_valid=1, conv_out= 162
Test ID 2: Status: PASSED
Test Description: Invalid Inputs (both valid_n low)
Expected: conv_valid=0, conv_out= 0
Actual: conv_valid=0, conv_out= 0
Test ID 3: Status: PASSED
Test Description: Zero Window Data, Non-zero Weights
Expected: conv_valid=1, conv_out= 0
Actual: conv_valid=1, conv_out= 0
Test ID 4: Status: PASSED
Test Description: Non-zero Window, Zero Weight Data
Expected: conv_valid=1, conv_out= 0
Actual: conv_valid=1, conv_out= 0
Test ID 5: Status: PASSED
Test Description: All Zero Inputs
Expected: conv_valid=1, conv_out= 0
Actual: conv_valid=1, conv_out= 0
Test ID 6: Status: PASSED
Test Description: Large values (no saturation with 20-bit output)
Expected: conv_valid=1, conv_out= 675
Actual: conv_valid=1, conv_out= 675
Test ID 7: Status: PASSED
Test Description: Max Val Inputs (Win= 255, Wgt= 255), should saturate to 1048575
Expected: conv_valid=1, conv_out=1048575
Actual: conv_valid=1, conv_out=1048575
Test ID 8: Status: PASSED
Test Description: Medium values to test 20-bit range (100*100, sum 270000)
Expected: conv_valid=1, conv_out= 270000
Actual: conv_valid=1, conv_out= 270000
Test ID 9: Status: PASSED
--- Test Sequence 9: Window Valid Toggles (base inputs 1*1, sum 27) ---
Sub-Test Description: WinValid=1 (Start)
Expected: conv_valid=1, conv_out= 27
Actual: conv_valid=1, conv_out= 27
Test ID 10: Status: PASSED
Sub-Test Description: WinValid=0
Expected: conv_valid=0, conv_out= 0
Actual: conv_valid=0, conv_out= 0
Test ID 11: Status: PASSED
Sub-Test Description: WinValid=1 (End)
Expected: conv_valid=1, conv_out= 27
Actual: conv_valid=1, conv_out= 27
Test ID 12: Status: PASSED
--- Test Sequence 10: Weight Valid Toggles (base inputs 1*1, sum 27) ---
Sub-Test Description: WeightValid=1 (Start)
Expected: conv_valid=1, conv_out= 27
Actual: conv_valid=1, conv_out= 27
Test ID 13: Status: PASSED
Sub-Test Description: WeightValid=0
Expected: conv_valid=0, conv_out= 0
Actual: conv_valid=0, conv_out= 0
Test ID 14: Status: PASSED
Sub-Test Description: WeightValid=1 (End)
Expected: conv_valid=1, conv_out= 27
Actual: conv_valid=1, conv_out= 27
Test ID 15: Status: PASSED