SpringBoot+MybatisPlus+自定义注解+切面实现水平数据隔离功能(附代码下载)
场景
业务场景中,需要对某些表中的数据做水平的数据隔离,比如某些表中如果含有某个字段,比如store_id(门店id)这个字段,
则对某些有对应门店权限的用户角色开放数据,如果请求的用户没有对该门店的权限,则自动对sql进行拦截添加where条件。
当然如果同一张表,又必须要查询全量数据,又可以通过添加自定义注解的方式,跳过数据隔离,返回全量数据。
并且如果用户没有任何门店的权限,或其他类似权限限制,则直接不执行查询,返回数据为空。
注:
博客:
https://blog.csdn.net/badao_liumang_qizhi
实现
新建SpringBoot项目,并引入相关依赖
如下依赖特别关注:
<!--MybatisPlus依赖--><dependency><groupId>com.baomidou</groupId><artifactId>mybatis-plus-boot-starter</artifactId><version>3.5.1</version></dependency><dependency><groupId>org.aspectj</groupId><artifactId>aspectjrt</artifactId><version>1.9.7</version></dependency><dependency><groupId>org.aspectj</groupId><artifactId>aspectjweaver</artifactId><version>1.9.7</version></dependency><dependency><groupId>org.springframework</groupId><artifactId>spring-aspects</artifactId></dependency>
注意:
mybatis-plus-boot-starter 3.5.1 已包含 JSqlParser 依赖
所以此处不需要额外引入如下依赖:
<dependency><groupId>com.github.jsqlparser</groupId><artifactId>jsqlparser</artifactId><version>4.3</version> <!-- MyBatis-Plus 3.5.1 使用的版本 -->
</dependency>
另外还需引入其它非关键依赖,按需选择:
<!-- spring-boot --><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-web</artifactId></dependency><!-- spring-boot-test --><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-test</artifactId><scope>test</scope><exclusions><exclusion><groupId>org.junit.vintage</groupId><artifactId>junit-vintage-engine</artifactId></exclusion></exclusions></dependency><!-- lombok --><dependency><groupId>org.projectlombok</groupId><artifactId>lombok</artifactId><version>1.18.26</version><scope>provided</scope></dependency><!-- 数据库连接 --><dependency><groupId>mysql</groupId><artifactId>mysql-connector-java</artifactId></dependency>
添加mybatisplus的配置类,在配置类中实现初始化表缓存、定期刷新表缓存、注册数据隔离拦截器操作
代码实现如下:
@Configuration
@MapperScan("com.badao.demo.mapper")
public class MybatisPlusConfig {// 关键功能:// 1. 初始化时扫描数据库表结构(initTableCache)// 2. 定时刷新表结构缓存(scheduleCacheRefresh)// 3. 注册MyBatis-Plus拦截器链public MybatisPlusConfig(DataSource dataSource) {this.dataSource = dataSource;initTableCache();//每5分钟刷新缓存(应对表结构变更)scheduleCacheRefresh();}// 数据源private final DataSource dataSource;// 模式名称public static final String DATABASE_B_GAS_STATION = "test";// 含有门店id字段的数据表// 使用ConcurrentHashMap保证线程安全private final Set<String> tablesWithStoreId = Collections.newSetFromMap(new ConcurrentHashMap<>());@Beanpublic MybatisPlusInterceptor mybatisPlusInterceptor() {MybatisPlusInterceptor interceptor = new MybatisPlusInterceptor();// 注册数据水平隔离拦截器interceptor.addInnerInterceptor(new StoreDataInterceptor(tablesWithStoreId));// 注册分页拦截器interceptor.addInnerInterceptor(new PaginationInnerInterceptor(DbType.MYSQL));return interceptor;}/*** 初始化表缓存*/private void initTableCache() {try (Connection conn = dataSource.getConnection()) {DatabaseMetaData metaData = conn.getMetaData();Map<String, Set<String>> tableColumnsMap = new HashMap<>();try (ResultSet columns = metaData.getColumns(DATABASE_B_GAS_STATION, null, "%", "%")) {while (columns.next()) {String tableName = columns.getString("TABLE_NAME").toLowerCase();String columnName = columns.getString("COLUMN_NAME").toLowerCase();tableColumnsMap.computeIfAbsent(tableName, k -> new HashSet<>()).add(columnName);}}// 获取所有表try (ResultSet tables = metaData.getTables(DATABASE_B_GAS_STATION, null, "%", new String[]{"TABLE"})) {while (tables.next()) {String tableName = tables.getString("TABLE_NAME").toLowerCase();Set<String> columns = tableColumnsMap.getOrDefault(tableName, Collections.emptySet());if (columns.contains(StoreDataInterceptor.STORE_ID)) {tablesWithStoreId.add(tableName);}}}} catch (Exception e) {throw new RuntimeException("SellerIso: Failed to init table cache", e);}}/*** 定时刷新表结构缓存*/private void scheduleCacheRefresh() {//刷新表结构缓存ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor();scheduler.scheduleAtFixedRate(this::initTableCache, 5, 5, TimeUnit.MINUTES);}
}
数据隔离拦截器实现代码
public class StoreDataInterceptor implements InnerInterceptor {private final Set<String> tablesWithStoreId;//隔离字段columnpublic static final String STORE_ID = "store_id";public StoreDataInterceptor(Set<String> tablesWithStoreId) {this.tablesWithStoreId = tablesWithStoreId;}/*** 优先级高于SQL改写* 若返回false,则不会触发后续的beforeQuery(SQL重写逻辑)*/@Overridepublic boolean willDoQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) {// 指定跳过数据隔离if (SkipDataIsolation.getMethodSkipDataIsolation()) {return true;}// 其它业务逻辑则不予查询,比如获取请求头中的数据做权限校验,完全禁止无权限的查询(如未登录用户)
// if(!CollectionUtils.isEmpty(UserContextHolder.getStoreIds())
// {
// return false;
// }return true;}@Overridepublic void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds,ResultHandler resultHandler, BoundSql boundSql) {// 如果用户是超管用户则跳过拦截器-自己添加逻辑判断if (false) {return;}// 指定跳过数据隔离if (SkipDataIsolation.getMethodSkipDataIsolation()) {return;}String sql = boundSql.getSql();try {//解析SQL并重写Select select = (Select) CCJSqlParserUtil.parse(sql);SelectBody selectBody = select.getSelectBody();// 递归处理所有SELECT部分processSelectBody(selectBody);PluginUtils.mpBoundSql(boundSql).sql(select.toString());} catch (Exception e) {System.out.println(e.getMessage());}}/*** 处理PlainSelect*/private void processSelectBody(SelectBody selectBody) {if (selectBody instanceof PlainSelect) {processPlainSelect((PlainSelect) selectBody);} else if (selectBody instanceof SetOperationList) {// 处理UNION/INTERSECT等for (SelectBody body : ((SetOperationList) selectBody).getSelects()) {processSelectBody(body);}}// 其他类型如WithItem暂不处理}/*** 处理FROM项*/private void processPlainSelect(PlainSelect plainSelect) {// 1. 处理FROM项Map<String, String> aliasTableMap = new HashMap<>();processFromItem(plainSelect.getFromItem(), aliasTableMap);// 2. 处理JOIN表if (plainSelect.getJoins() != null) {for (Join join : plainSelect.getJoins()) {processFromItem(join.getRightItem(), aliasTableMap);}}// 3. 添加条件到当前SELECTaddConditionsToSelect(plainSelect, aliasTableMap);// 4. 递归处理子查询processSubQueries(plainSelect);}/*** 处理查询*/private void processFromItem(FromItem fromItem, Map<String, String> aliasTableMap) {if (fromItem instanceof Table) {Table table = (Table) fromItem;String tableName = table.getName().toLowerCase();String alias = table.getAlias() != null ?table.getAlias().getName().toLowerCase() : tableName;// 缓存别名映射aliasTableMap.put(alias, tableName);} else if (fromItem instanceof SubSelect) {// 处理子查询processSelectBody(((SubSelect) fromItem).getSelectBody());}}/*** 处理子查询*/private void processSubQueries(PlainSelect plainSelect) {// 1. 处理WHERE子句中的子查询if (plainSelect.getWhere() != null) {plainSelect.getWhere().accept(new SafeExpressionVisitor());}// 2. 处理SELECT列表中的子查询for (SelectItem item : plainSelect.getSelectItems()) {item.accept(new SafeSelectItemVisitor());}}/*** 添加查询条件到SELECT*/private void addConditionsToSelect(PlainSelect plainSelect, Map<String, String> aliasTableMap) {// 检查哪些表需要添加条件for (Map.Entry<String, String> entry : aliasTableMap.entrySet()) {String alias = entry.getKey();String tableName = entry.getValue();//List<String> storeIds = UserContextHolder.getStoreIds();//此处用模拟数据示例List<String> storeIds = new ArrayList(){{this.add("1");this.add("2");}};//对含store_id的表自动添加条件:if (tablesWithStoreId.contains(tableName)) {handleSelectSql(alias, plainSelect, storeIds);}}}/*** 创建查询表达式*/private static void handleSelectSql(String alias, PlainSelect plainSelect,List<String> companyChannelIds) {// 创建条件表达式Column channelColumn = new Column(alias + "." + StoreDataInterceptor.STORE_ID);// 创建表达式列表ExpressionList expressionList = new ExpressionList();// 手动初始化expressions列表expressionList.setExpressions(new ArrayList<>());for (String id : companyChannelIds) {expressionList.getExpressions().add(new StringValue(id));}// 构建条件表达式:WHERE (store_id IN (1,2) OR store_id IS NULL)// 创建IN表达式InExpression inExpression = new InExpression(channelColumn, expressionList);// 添加or 数据隔离字段is null 条件避免联表查询时未能关联数据导致全部数据被过滤IsNullExpression isNullExpression = new IsNullExpression();isNullExpression.setLeftExpression(channelColumn);OrExpression orExpression = new OrExpression(isNullExpression, inExpression);// 调整or条件优先级 加()Parenthesis parenthesis = new Parenthesis(orExpression);// 获取现有WHERE条件Expression where = plainSelect.getWhere();plainSelect.setWhere(where == null ? parenthesis : new AndExpression(where, parenthesis));}/*** 避免查询无限递归*/private class SafeExpressionVisitor extends ExpressionVisitorAdapter {private final Set<Object> visitedObjects = Collections.newSetFromMap(new IdentityHashMap<>());@Overridepublic void visit(SubSelect subSelect) {// 防止SubSelect无限递归if (visitedObjects.add(subSelect)) {try {// 限制递归深度if (visitedObjects.size() < 50) {processSelectBody(subSelect.getSelectBody());}} catch (Exception e) {System.out.println(e.getMessage());} finally {visitedObjects.remove(subSelect);}}}@Overridepublic void visit(AllColumns allColumns) {// 关键:避免处理AllColumns时的无限递归// 在JSqlParser 4.3中,这里不能调用super.visit(allColumns)}}/*** 避免查询无限递归*/private class SafeSelectItemVisitor extends SelectItemVisitorAdapter {@Overridepublic void visit(SelectExpressionItem item) {try {item.getExpression().accept(new SafeExpressionVisitor());} catch (Exception e) {System.out.println(e.getMessage());}}}
}
代码如下:
注意:
1、willDoQuery中
核心作用
拦截器开关控制
决定是否允许当前SQL查询继续执行(true放行,false拦截)
与跳过机制集成
通过检查SkipDataIsolation的线程状态,实现动态拦截控制
方法调用时机
sequenceDiagram
MyBatis->>StoreDataInterceptor: 执行查询前
StoreDataInterceptor->>willDoQuery: 检查拦截条件
alt 返回true
MyBatis->>DB: 正常执行查询
else 返回false
MyBatis->>调用方: 直接返回空结果
end
优先级高于SQL改写
若返回false,则不会触发后续的beforeQuery(SQL重写逻辑)
典型使用场景
完全禁止无权限的查询(如未登录用户)
快速跳过无需处理的查询类型(如特定Mapper方法)
2、beforeQuery中
如果用户是超管用户则跳过拦截器-自己添加逻辑判断
addConditionsToSelect添加查询条件中SELECT中,获取当前用户的门店id权限使用模拟数据演示效果。
正常应该是从权限控制相关业务中获取,此处注意使用时修改。
自定义跳过数据隔离注解实现
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface SkipDataIsolationAnnotation {
}
跳过数据隔离切面实现
/*** 跳过数据隔离切面*/
@Aspect
@Component
public class SkipDataIsolationAspect {//Around增强:在方法执行前后插入逻辑@Around("@annotation(skipDataIsolationAnnotation)")public Object handleSkipDataIsolation(ProceedingJoinPoint joinPoint,SkipDataIsolationAnnotation skipDataIsolationAnnotation) throws Throwable {try {//进入方法时设置ThreadLocal标志为trueSkipDataIsolation.setMethodSkipDataIsolation(true);// 设置线程标志return joinPoint.proceed();} finally {//通过try-finally确保异常时也能清理状态SkipDataIsolation.methodClear(); // 清理线程状态}}
}
上下文控制器实现
/*** 上下文控制器*/
public class SkipDataIsolation {// 单次sql语句级别跳过数据隔离: 使用ThreadLocal存储跳过数据隔离的标志,默认不跳过value=falseprivate static final ThreadLocal<Boolean> SKIP_DATA_ISOLATION = ThreadLocal.withInitial(() -> false);// 方法级别跳过数据隔离: 使用ThreadLocal存储跳过数据隔离的标志, 默认不跳过value=falseprivate static final ThreadLocal<Boolean> SKIP_DATA_ISOLATION_METHOD = ThreadLocal.withInitial(() -> false);/*** 单次sql级别:设置跳过数据隔离标志*/public static void setSkipDataIsolation(Boolean skip) {SKIP_DATA_ISOLATION.set(skip);}/*** 单次sql级别:获取跳过数据隔离标志*/public static Boolean getSkipDataIsolation() {return SKIP_DATA_ISOLATION.get();}/*** 单次sql级别:清理ThreadLocal,防止内存泄漏*/public static void clear() {SKIP_DATA_ISOLATION.remove();}/*** 方法级别:设置跳过数据隔离标志*/public static void setMethodSkipDataIsolation(Boolean skip) {SKIP_DATA_ISOLATION_METHOD.set(skip);}/*** 方法级别:获取跳过数据隔离标志*/public static Boolean getMethodSkipDataIsolation() {return SKIP_DATA_ISOLATION_METHOD.get();}/*** 方法级别:清理ThreadLocal,防止内存泄漏*/public static void methodClear() {SKIP_DATA_ISOLATION_METHOD.remove();}
}
测试效果
新建一个包含store_id字段的表,并生成5条数据,其中有两条数据store_id为1和2。
新建两个controller并且一个添加跳过数据隔离注解,一个不添加,执行同样的mp的条件查询。
不进行数据隔离的查询效果

带数据隔离的效果
完整示例代码以及SQL文件资源下载
https://download.csdn.net/download/BADAO_LIUMANG_QIZHI/92218402

