diff --git a/ruoyi-admin/src/main/resources/application-druid.yml b/ruoyi-admin/src/main/resources/application-druid.yml index 6c5e4ae3c889a4d92c594922896104cbec42418f..489a38b9471b9be7683ce6d8bd90df809c9ecbc5 100644 --- a/ruoyi-admin/src/main/resources/application-druid.yml +++ b/ruoyi-admin/src/main/resources/application-druid.yml @@ -10,7 +10,7 @@ spring: MASTER: url: jdbc:mysql://127.0.0.1/ry?useUnicode=true&characterEncoding=utf8&zeroDateTimeBehavior=convertToNull&useSSL=true&serverTimezone=GMT%2B8 username: root - password: 123456 + password: WKY20031018 # 从库数据源 # SLAVE: # url: jdbc:mysql://127.0.0.1/ruoyi?useUnicode=true&characterEncoding=utf8&zeroDateTimeBehavior=convertToNull&useSSL=true&serverTimezone=GMT%2B8 diff --git a/ruoyi-common/src/main/java/com/ruoyi/common/annotation/DataSource.java b/ruoyi-common/src/main/java/com/ruoyi/common/annotation/DataSource.java index 39d9d30ed7b7390d7a2b78e1b2864088e307a727..79cd191f8e2f85e0a28dae98880a88dd134e49fc 100644 --- a/ruoyi-common/src/main/java/com/ruoyi/common/annotation/DataSource.java +++ b/ruoyi-common/src/main/java/com/ruoyi/common/annotation/DataSource.java @@ -6,7 +6,6 @@ import java.lang.annotation.Inherited; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; - import com.ruoyi.common.enums.DataSourceType; /** @@ -20,15 +19,10 @@ import com.ruoyi.common.enums.DataSourceType; @Retention(RetentionPolicy.RUNTIME) @Documented @Inherited -public @interface DataSource { - +public @interface DataSource +{ /** - * 切换数据源名称 - 枚举方式 + * 切换数据源名称 */ public DataSourceType value() default DataSourceType.MASTER; - - /** - * 切换数据源名称 - 字符串方式 - */ - public String name() default ""; } diff --git a/ruoyi-common/src/main/java/com/ruoyi/common/annotation/sql/DataSecurity.java b/ruoyi-common/src/main/java/com/ruoyi/common/annotation/sql/DataSecurity.java new file mode 100644 index 0000000000000000000000000000000000000000..6d14bd19a363cc7ef2ffb0af884b8e4b2dd4c028 --- /dev/null +++ b/ruoyi-common/src/main/java/com/ruoyi/common/annotation/sql/DataSecurity.java @@ -0,0 +1,20 @@ +package com.ruoyi.common.annotation.sql; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import com.ruoyi.common.enums.DataSecurityStrategy; + +@Target(ElementType.METHOD) +@Retention(RetentionPolicy.RUNTIME) +@Documented +public @interface DataSecurity { + public DataSecurityStrategy strategy() default DataSecurityStrategy.CREEATE_BY; + + public String table() default ""; + + public String joinTableAlise() default ""; +} diff --git a/ruoyi-common/src/main/java/com/ruoyi/common/annotation/sql/MybatisHandlerOrder.java b/ruoyi-common/src/main/java/com/ruoyi/common/annotation/sql/MybatisHandlerOrder.java new file mode 100644 index 0000000000000000000000000000000000000000..4c75c030dd4ec8ad7a323da420e5f5b0396e720b --- /dev/null +++ b/ruoyi-common/src/main/java/com/ruoyi/common/annotation/sql/MybatisHandlerOrder.java @@ -0,0 +1,14 @@ +package com.ruoyi.common.annotation.sql; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +@Target(ElementType.TYPE) +@Retention(RetentionPolicy.RUNTIME) +@Documented +public @interface MybatisHandlerOrder { + public int value() default 0; +} diff --git a/ruoyi-common/src/main/java/com/ruoyi/common/context/dataSecurity/DataSecurityContextHolder.java b/ruoyi-common/src/main/java/com/ruoyi/common/context/dataSecurity/DataSecurityContextHolder.java new file mode 100644 index 0000000000000000000000000000000000000000..9b381b649e5ee786b9ec5ff4cdf381389eb4fc70 --- /dev/null +++ b/ruoyi-common/src/main/java/com/ruoyi/common/context/dataSecurity/DataSecurityContextHolder.java @@ -0,0 +1,48 @@ +package com.ruoyi.common.context.dataSecurity; + +import java.util.List; +import java.util.Map; + +import com.alibaba.fastjson2.JSONArray; +import com.alibaba.fastjson2.JSONObject; +import com.ruoyi.common.enums.SqlType; +import com.ruoyi.common.model.JoinTableModel; +import com.ruoyi.common.model.WhereModel; + +public class DataSecurityContextHolder { + private static final ThreadLocal DATA_SECURITY_SQL_CONTEXT_HOLDER = new ThreadLocal<>(); + + public static void startDataSecurity() { + JSONObject jsonObject = new JSONObject(); + jsonObject.put("isSecurity", Boolean.TRUE); + jsonObject.put(SqlType.WHERE.getSqlType(), new JSONArray()); + jsonObject.put(SqlType.JOIN.getSqlType(), new JSONArray()); + DATA_SECURITY_SQL_CONTEXT_HOLDER.set(jsonObject); + } + + public static void addWhereParam(WhereModel whereModel) { + DATA_SECURITY_SQL_CONTEXT_HOLDER.get().getJSONArray(SqlType.WHERE.getSqlType()).add(whereModel); + } + + public static void clearCache() { + DATA_SECURITY_SQL_CONTEXT_HOLDER.remove(); + } + + public static boolean isSecurity() { + + return DATA_SECURITY_SQL_CONTEXT_HOLDER.get() != null + && DATA_SECURITY_SQL_CONTEXT_HOLDER.get().getBooleanValue("isSecurity"); + } + + public static JSONArray getWhere() { + return DATA_SECURITY_SQL_CONTEXT_HOLDER.get().getJSONArray(SqlType.WHERE.getSqlType()); + } + + public static void addJoinTable(JoinTableModel joinTableModel) { + DATA_SECURITY_SQL_CONTEXT_HOLDER.get().getJSONArray(SqlType.JOIN.getSqlType()).add(joinTableModel); + } + + public static JSONArray getJoinTables() { + return DATA_SECURITY_SQL_CONTEXT_HOLDER.get().getJSONArray(SqlType.JOIN.getSqlType()); + } +} diff --git a/ruoyi-common/src/main/java/com/ruoyi/common/context/page/PageContextHolder.java b/ruoyi-common/src/main/java/com/ruoyi/common/context/page/PageContextHolder.java new file mode 100644 index 0000000000000000000000000000000000000000..ad0bcd3cf5ca665abbfc97ea77d609eb5ba554a9 --- /dev/null +++ b/ruoyi-common/src/main/java/com/ruoyi/common/context/page/PageContextHolder.java @@ -0,0 +1,44 @@ +package com.ruoyi.common.context.page; + +import com.alibaba.fastjson2.JSONObject; +import com.ruoyi.common.context.page.model.PageInfo; + +public class PageContextHolder { + private static final ThreadLocal PAGE_CONTEXT_HOLDER = new ThreadLocal<>(); + + private static final String PAGE_FLAG = "isPage"; + + private static final String PAGE_INFO = "pageInfo"; + + private static final String TOTAL = "total"; + + public static void startPage() { + JSONObject jsonObject = new JSONObject(); + jsonObject.put(PAGE_FLAG, Boolean.TRUE); + PAGE_CONTEXT_HOLDER.set(jsonObject); + } + + public static void setPageInfo() { + PAGE_CONTEXT_HOLDER.get().put(PAGE_INFO, PageInfo.defaultPageInfo()); + } + + public static PageInfo getPageInfo() { + return (PageInfo) PAGE_CONTEXT_HOLDER.get().get(PAGE_INFO); + } + + public static void clear() { + PAGE_CONTEXT_HOLDER.remove(); + } + + public static boolean isPage() { + return PAGE_CONTEXT_HOLDER.get() != null && PAGE_CONTEXT_HOLDER.get().getBooleanValue(PAGE_FLAG); + } + + public static void setTotal(Long total) { + PAGE_CONTEXT_HOLDER.get().put(TOTAL, total); + } + + public static Long getTotal() { + return PAGE_CONTEXT_HOLDER.get().getLong(TOTAL); + } +} diff --git a/ruoyi-common/src/main/java/com/ruoyi/common/context/page/model/PageInfo.java b/ruoyi-common/src/main/java/com/ruoyi/common/context/page/model/PageInfo.java new file mode 100644 index 0000000000000000000000000000000000000000..dfca4adcfe4939d7c232f57742b53b71a5ad9c6c --- /dev/null +++ b/ruoyi-common/src/main/java/com/ruoyi/common/context/page/model/PageInfo.java @@ -0,0 +1,63 @@ +package com.ruoyi.common.context.page.model; + +import com.ruoyi.common.core.text.Convert; +import com.ruoyi.common.utils.ServletUtils; + +public class PageInfo { + + private Long pageNumber; + + private Long pageSize; + + /** + * 当前记录起始索引 + */ + public static final String PAGE_NUM = "pageNum"; + + /** + * 每页显示记录数 + */ + public static final String PAGE_SIZE = "pageSize"; + + /** + * 排序列 + */ + public static final String ORDER_BY_COLUMN = "orderByColumn"; + + /** + * 排序的方向 "desc" 或者 "asc". + */ + public static final String IS_ASC = "isAsc"; + + /** + * 分页参数合理化 + */ + public static final String REASONABLE = "reasonable"; + + public Long getPageNumber() { + return pageNumber; + } + + public void setPageNumber(Long pageNumber) { + this.pageNumber = pageNumber; + } + + public Long getPageSize() { + return pageSize; + } + + public void setPageSize(Long pageSize) { + this.pageSize = pageSize; + } + + public static PageInfo defaultPageInfo() { + PageInfo pageInfo = new PageInfo(); + pageInfo.setPageNumber(Long.valueOf(Convert.toInt(ServletUtils.getParameter(PAGE_NUM), 1))); + pageInfo.setPageSize(Long.valueOf(Convert.toInt(ServletUtils.getParameter(PAGE_SIZE), 10))); + return pageInfo; + } + + public Long getOffeset() { + return (pageNumber.longValue() - 1L) * pageSize; + } +} diff --git a/ruoyi-common/src/main/java/com/ruoyi/common/context/page/model/RuoyiTableData.java b/ruoyi-common/src/main/java/com/ruoyi/common/context/page/model/RuoyiTableData.java new file mode 100644 index 0000000000000000000000000000000000000000..b4d3711fb20b6dae6838980aa7526bc55629fca9 --- /dev/null +++ b/ruoyi-common/src/main/java/com/ruoyi/common/context/page/model/RuoyiTableData.java @@ -0,0 +1,25 @@ +package com.ruoyi.common.context.page.model; + +import java.util.List; + +public class RuoyiTableData { + private Long total; + private List data; + + public Long getTotal() { + return total; + } + + public void setTotal(Long total) { + this.total = total; + } + + public List getData() { + return data; + } + + public void setData(List data) { + this.data = data; + } + +} diff --git a/ruoyi-common/src/main/java/com/ruoyi/common/context/page/model/TableInfo.java b/ruoyi-common/src/main/java/com/ruoyi/common/context/page/model/TableInfo.java new file mode 100644 index 0000000000000000000000000000000000000000..8e0d722368b5eb5f5cef129763bd54e802bfca72 --- /dev/null +++ b/ruoyi-common/src/main/java/com/ruoyi/common/context/page/model/TableInfo.java @@ -0,0 +1,22 @@ +package com.ruoyi.common.context.page.model; + +import java.util.ArrayList; +import java.util.List; + +public class TableInfo extends ArrayList { + + private Long total; + + public TableInfo(List list) { + super(list); + } + + public Long getTotal() { + return total; + } + + public void setTotal(Long total) { + this.total = total; + } + +} diff --git a/ruoyi-common/src/main/java/com/ruoyi/common/enums/DataSecurityStrategy.java b/ruoyi-common/src/main/java/com/ruoyi/common/enums/DataSecurityStrategy.java new file mode 100644 index 0000000000000000000000000000000000000000..f3b07701412620c13b8b21610802a0b53faf31ef --- /dev/null +++ b/ruoyi-common/src/main/java/com/ruoyi/common/enums/DataSecurityStrategy.java @@ -0,0 +1,8 @@ +package com.ruoyi.common.enums; + +public enum DataSecurityStrategy { + JOINTABLE_CREATE_BY, + JOINTABLE_USER_ID, + CREEATE_BY, + USER_ID; +} diff --git a/ruoyi-common/src/main/java/com/ruoyi/common/enums/SqlType.java b/ruoyi-common/src/main/java/com/ruoyi/common/enums/SqlType.java new file mode 100644 index 0000000000000000000000000000000000000000..b100ce7b599b6fc8ab1a2b76b73ed5fd3d93f662 --- /dev/null +++ b/ruoyi-common/src/main/java/com/ruoyi/common/enums/SqlType.java @@ -0,0 +1,18 @@ +package com.ruoyi.common.enums; + +public enum SqlType { + WHERE("where"), + JOIN("join"), + SELECT("select"), + LIMIT("limit"); + + private String sqlType; + + public String getSqlType() { + return sqlType; + } + + private SqlType(String sqlType) { + this.sqlType = sqlType; + } +} diff --git a/ruoyi-common/src/main/java/com/ruoyi/common/handler/sql/MybatisAfterHandler.java b/ruoyi-common/src/main/java/com/ruoyi/common/handler/sql/MybatisAfterHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..3713b399e0f7fbd8923f0f95f2e0a258a3bd50b1 --- /dev/null +++ b/ruoyi-common/src/main/java/com/ruoyi/common/handler/sql/MybatisAfterHandler.java @@ -0,0 +1,7 @@ +package com.ruoyi.common.handler.sql; + +public interface MybatisAfterHandler { + + Object handleObject(Object object) throws Throwable; + +} \ No newline at end of file diff --git a/ruoyi-common/src/main/java/com/ruoyi/common/handler/sql/MybatisPreHandler.java b/ruoyi-common/src/main/java/com/ruoyi/common/handler/sql/MybatisPreHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..676d97c1a828c315605504fad625d6a7e39a2c00 --- /dev/null +++ b/ruoyi-common/src/main/java/com/ruoyi/common/handler/sql/MybatisPreHandler.java @@ -0,0 +1,15 @@ +package com.ruoyi.common.handler.sql; + +import org.apache.ibatis.cache.CacheKey; +import org.apache.ibatis.executor.Executor; +import org.apache.ibatis.mapping.BoundSql; +import org.apache.ibatis.mapping.MappedStatement; +import org.apache.ibatis.session.ResultHandler; +import org.apache.ibatis.session.RowBounds; + +public interface MybatisPreHandler { + + void preHandle(Executor executor, MappedStatement mappedStatement, Object params, + RowBounds rowBounds, ResultHandler resultHandler, CacheKey cacheKey, BoundSql boundSql) + throws Throwable; +} \ No newline at end of file diff --git a/ruoyi-common/src/main/java/com/ruoyi/common/handler/sql/dataSecurity/DataSecurityPreHandler.java b/ruoyi-common/src/main/java/com/ruoyi/common/handler/sql/dataSecurity/DataSecurityPreHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..1021b09e635a8697ba7ac7aa8dd0312ceb8f8199 --- /dev/null +++ b/ruoyi-common/src/main/java/com/ruoyi/common/handler/sql/dataSecurity/DataSecurityPreHandler.java @@ -0,0 +1,100 @@ +package com.ruoyi.common.handler.sql.dataSecurity; + +import java.lang.reflect.Field; +import java.util.List; + +import org.apache.ibatis.cache.CacheKey; +import org.apache.ibatis.executor.Executor; +import org.apache.ibatis.mapping.BoundSql; +import org.apache.ibatis.mapping.MappedStatement; +import org.apache.ibatis.session.ResultHandler; +import org.apache.ibatis.session.RowBounds; +import org.springframework.stereotype.Component; +import org.springframework.util.ReflectionUtils; + +import com.ruoyi.common.annotation.sql.MybatisHandlerOrder; +import com.ruoyi.common.context.dataSecurity.DataSecurityContextHolder; +import com.ruoyi.common.handler.sql.MybatisPreHandler; +import com.ruoyi.common.model.JoinTableModel; +import com.ruoyi.common.model.WhereModel; +import com.ruoyi.common.utils.StringUtils; +import com.ruoyi.common.utils.sql.SqlUtil; + +import net.sf.jsqlparser.JSQLParserException; +import net.sf.jsqlparser.expression.Alias; +import net.sf.jsqlparser.expression.Expression; +import net.sf.jsqlparser.expression.operators.relational.EqualsTo; +import net.sf.jsqlparser.parser.CCJSqlParserUtil; +import net.sf.jsqlparser.schema.Column; +import net.sf.jsqlparser.schema.Table; +import net.sf.jsqlparser.statement.Statement; +import net.sf.jsqlparser.statement.select.Join; +import net.sf.jsqlparser.statement.select.PlainSelect; +import net.sf.jsqlparser.statement.select.Select; + +@MybatisHandlerOrder(1) +@Component +public class DataSecurityPreHandler implements MybatisPreHandler { + + private static final Field sqlFiled = ReflectionUtils.findField(BoundSql.class, "sql"); + static { + sqlFiled.setAccessible(true); + } + + @Override + public void preHandle(Executor executor, MappedStatement mappedStatement, Object params, RowBounds rowBounds, + ResultHandler resultHandler, CacheKey cacheKey, BoundSql boundSql) throws Throwable { + if (DataSecurityContextHolder.isSecurity()) { + Statement sql = parseSql(SqlUtil.parseSql(boundSql.getSql())); + sqlFiled.set(boundSql, sql.toString()); + } + } + + private static Statement parseSql(Statement statement) throws JSQLParserException { + if (statement instanceof Select) { + Select select = (Select) statement; + // plain.setWhere(CCJSqlParserUtil.parseCondExpression(handleWhere(expWhere))); + handleWhere(select); + handleJoin(select); + return select; + } else { + return statement; + } + } + + private static void handleWhere(Select select) throws JSQLParserException { + PlainSelect plain = select.getPlainSelect(); + Expression expWhere = plain.getWhere(); + StringBuilder whereParam = new StringBuilder(" "); + String where = expWhere != null ? expWhere.toString() : null; + if (DataSecurityContextHolder.getWhere() == null || DataSecurityContextHolder.getWhere().size() <= 0) { + return; + } + DataSecurityContextHolder.getWhere().forEach(item -> { + whereParam.append(((WhereModel) item).getSqlString()); + }); + where = StringUtils.isEmpty(where) ? whereParam.toString().substring(5, whereParam.length()) + : where + " " + whereParam.toString(); + plain.setWhere(CCJSqlParserUtil.parseCondExpression(where)); + } + + private static void handleJoin(Select select) { + PlainSelect selectBody = select.getPlainSelect(); + if (DataSecurityContextHolder.getJoinTables() == null || DataSecurityContextHolder.getJoinTables().size() <= 0) { + return; + } + DataSecurityContextHolder.getJoinTables().forEach(item -> { + JoinTableModel tableModel = (JoinTableModel) item; + Table table = new Table(tableModel.getJoinTable()); + table.setAlias(new Alias(tableModel.getJoinTableAlise())); + Join join = new Join(); + join.setRightItem(table); + join.setInner(true); + Expression onExpression = new EqualsTo(new Column(tableModel.getFromTableColumnString()), + new Column(tableModel.getJoinTableColumnString())); + join.setOnExpressions(List.of(onExpression)); + selectBody.addJoins(join); + }); + } + +} diff --git a/ruoyi-common/src/main/java/com/ruoyi/common/handler/sql/page/PageAfterHandler.java b/ruoyi-common/src/main/java/com/ruoyi/common/handler/sql/page/PageAfterHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..8a5d670e2daa2f8029746d1592eaa7856bbb63ba --- /dev/null +++ b/ruoyi-common/src/main/java/com/ruoyi/common/handler/sql/page/PageAfterHandler.java @@ -0,0 +1,31 @@ +package com.ruoyi.common.handler.sql.page; + +import java.util.ArrayList; +import java.util.List; + +import org.springframework.stereotype.Component; + +import com.ruoyi.common.annotation.sql.MybatisHandlerOrder; +import com.ruoyi.common.context.page.PageContextHolder; +import com.ruoyi.common.context.page.model.TableInfo; +import com.ruoyi.common.handler.sql.MybatisAfterHandler; + +@MybatisHandlerOrder(1) +@Component +public class PageAfterHandler implements MybatisAfterHandler { + + @Override + public Object handleObject(Object object) throws Throwable { + if (PageContextHolder.isPage()) { + if (object instanceof List) { + TableInfo tableInfo = new TableInfo<>((List) object); + tableInfo.setTotal(PageContextHolder.getTotal()); + PageContextHolder.clear(); + return tableInfo; + } + return object; + } + return object; + } + +} diff --git a/ruoyi-common/src/main/java/com/ruoyi/common/handler/sql/page/PagePreHandler.java b/ruoyi-common/src/main/java/com/ruoyi/common/handler/sql/page/PagePreHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..387eeb42b74e809f46b95f98409982d7abeab36c --- /dev/null +++ b/ruoyi-common/src/main/java/com/ruoyi/common/handler/sql/page/PagePreHandler.java @@ -0,0 +1,142 @@ +package com.ruoyi.common.handler.sql.page; + +import java.lang.reflect.Field; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import org.apache.ibatis.cache.CacheKey; +import org.apache.ibatis.executor.Executor; +import org.apache.ibatis.mapping.BoundSql; +import org.apache.ibatis.mapping.MappedStatement; +import org.apache.ibatis.mapping.ResultMap; +import org.apache.ibatis.mapping.ResultMapping; +import org.apache.ibatis.session.ResultHandler; +import org.apache.ibatis.session.RowBounds; +import org.springframework.stereotype.Component; +import org.springframework.util.ReflectionUtils; + +import com.ruoyi.common.annotation.sql.MybatisHandlerOrder; +import com.ruoyi.common.context.page.PageContextHolder; +import com.ruoyi.common.context.page.model.PageInfo; +import com.ruoyi.common.handler.sql.MybatisPreHandler; +import com.ruoyi.common.utils.sql.SqlUtil; +import net.sf.jsqlparser.schema.Column; +import net.sf.jsqlparser.statement.Statement; +import net.sf.jsqlparser.statement.select.Limit; +import net.sf.jsqlparser.statement.select.PlainSelect; +import net.sf.jsqlparser.statement.select.Select; +import net.sf.jsqlparser.statement.select.SelectItem; + +@Component +@MybatisHandlerOrder(2) +public class PagePreHandler implements MybatisPreHandler { + + private static final List EMPTY_RESULTMAPPING = new ArrayList(0); + + private static final String SELECT_COUNT_SUFIX = "_SELECT_COUNT"; + private static final Field sqlFiled = ReflectionUtils.findField(BoundSql.class, "sql"); + static { + sqlFiled.setAccessible(true); + } + + @Override + public void preHandle(Executor executor, MappedStatement mappedStatement, Object params, RowBounds rowBounds, + ResultHandler resultHandler, CacheKey cacheKey, BoundSql boundSql) throws Throwable { + if (PageContextHolder.isPage()) { + String originSql = boundSql.getSql(); + Statement sql = SqlUtil.parseSql(originSql); + if (sql instanceof Select) { + PageInfo pageInfo = PageContextHolder.getPageInfo(); + Statement handleLimit = handleLimit((Select) sql, pageInfo); + Statement countSql = getCountSql((Select) sql); + Long count = getCount(executor, mappedStatement, params, boundSql, rowBounds, resultHandler, + countSql.toString()); + PageContextHolder.setTotal(count); + sqlFiled.set(boundSql, handleLimit.toString()); + cacheKey = executor.createCacheKey(mappedStatement, params, rowBounds, boundSql); + } + } + + } + + private static MappedStatement createCountMappedStatement(MappedStatement ms, String newMsId) { + MappedStatement.Builder builder = new MappedStatement.Builder(ms.getConfiguration(), newMsId, + ms.getSqlSource(), + ms.getSqlCommandType()); + builder.resource(ms.getResource()); + builder.fetchSize(ms.getFetchSize()); + builder.statementType(ms.getStatementType()); + builder.keyGenerator(ms.getKeyGenerator()); + if (ms.getKeyProperties() != null && ms.getKeyProperties().length != 0) { + StringBuilder keyProperties = new StringBuilder(); + for (String keyProperty : ms.getKeyProperties()) { + keyProperties.append(keyProperty).append(","); + } + keyProperties.delete(keyProperties.length() - 1, keyProperties.length()); + builder.keyProperty(keyProperties.toString()); + } + builder.timeout(ms.getTimeout()); + builder.parameterMap(ms.getParameterMap()); + // count查询返回值int + List resultMaps = new ArrayList(); + ResultMap resultMap = new ResultMap.Builder(ms.getConfiguration(), ms.getId(), Long.class, + EMPTY_RESULTMAPPING) + .build(); + resultMaps.add(resultMap); + builder.resultMaps(resultMaps); + builder.resultSetType(ms.getResultSetType()); + builder.cache(ms.getCache()); + builder.flushCacheRequired(ms.isFlushCacheRequired()); + builder.useCache(ms.isUseCache()); + return builder.build(); + } + + public static Long getCount(Executor executor, MappedStatement mappedStatement, Object parameter, + BoundSql boundSql, RowBounds rowBounds, ResultHandler resultHandler, String countSql) + throws SQLException { + + Map additionalParameters = boundSql.getAdditionalParameters(); + + BoundSql countBoundSql = new BoundSql(mappedStatement.getConfiguration(), countSql, + boundSql.getParameterMappings(), parameter); + for (String key : additionalParameters.keySet()) { + countBoundSql.setAdditionalParameter(key, additionalParameters.get(key)); + } + CacheKey countKey = executor.createCacheKey(mappedStatement, parameter, RowBounds.DEFAULT, countBoundSql); + + List query = executor.query( + createCountMappedStatement(mappedStatement, getCountMSId(mappedStatement)), + parameter, RowBounds.DEFAULT, resultHandler, countKey, + countBoundSql); + return (Long) query.get(0); + } + + private static String getCountMSId(MappedStatement mappedStatement) { + return mappedStatement.getId() + SELECT_COUNT_SUFIX; + } + + public static Statement getCountSql(Select select) { + PlainSelect plain = select.getPlainSelect(); + PlainSelect countPlain = new PlainSelect(); + countPlain.setSelectItems(List.of(new SelectItem<>(new Column("COUNT(0)")))); + countPlain.setJoins(plain.getJoins()); + countPlain.setWhere(plain.getWhere()); + countPlain.setFromItem(plain.getFromItem()); + countPlain.setDistinct(plain.getDistinct()); + countPlain.setHaving(plain.getHaving()); + countPlain.setIntoTables(plain.getIntoTables()); + // countPlain.setOrderByElements(plain.getOrderByElements()); + return plain; + } + + private static Statement handleLimit(Select select, PageInfo pageInfo) { + Limit limit = new Limit(); + limit.setRowCount(new Column(pageInfo.getPageSize().toString())); + limit.setOffset(new Column(pageInfo.getOffeset().toString())); + PlainSelect plain = select.getPlainSelect(); + plain.setLimit(limit); + return select; + } + +} diff --git a/ruoyi-common/src/main/java/com/ruoyi/common/model/JoinTableModel.java b/ruoyi-common/src/main/java/com/ruoyi/common/model/JoinTableModel.java new file mode 100644 index 0000000000000000000000000000000000000000..6a6cb75ae4f178c2bc89d82224361e62aa0e7d47 --- /dev/null +++ b/ruoyi-common/src/main/java/com/ruoyi/common/model/JoinTableModel.java @@ -0,0 +1,85 @@ +package com.ruoyi.common.model; + +import com.ruoyi.common.utils.StringUtils; + +public class JoinTableModel { + private String joinTable; + + private String joinTableAlise; + + private String fromTable; + + private String fromTableAlise; + + private String joinTableColumn; + + private String fromTableColumn; + + public String getJoinTable() { + return joinTable; + } + + public void setJoinTable(String joinTable) { + this.joinTable = joinTable; + } + + public String getJoinTableAlise() { + if (StringUtils.isEmpty(this.joinTableAlise)) { + return this.joinTable; + } + return joinTableAlise; + } + + public void setJoinTableAlise(String joinTableAlise) { + + this.joinTableAlise = joinTableAlise; + } + + public String getFromTable() { + return fromTable; + } + + public void setFromTable(String fromTable) { + this.fromTable = fromTable; + } + + public String getFromTableAlise() { + if (StringUtils.isEmpty(this.fromTableAlise)) { + return this.fromTable; + } + return fromTableAlise; + } + + public void setFromTableAlise(String fromTableAlise) { + this.fromTableAlise = fromTableAlise; + } + + public String getJoinTableColumn() { + + return joinTableColumn; + } + + public void setJoinTableColumn(String joinTableColumn) { + this.joinTableColumn = joinTableColumn; + } + + public String getFromTableColumn() { + return fromTableColumn; + } + + public void setFromTableColumn(String fromTableColumn) { + this.fromTableColumn = fromTableColumn; + } + + public String getJoinTableColumnString() { + return this.getJoinTableAlise() + "." + this.joinTableColumn; + } + + public String getFromTableColumnString() { + if (StringUtils.isEmpty(this.getFromTableAlise())) { + return this.fromTableColumn; + } + return this.getFromTableAlise() + "." + this.fromTableColumn; + } + +} diff --git a/ruoyi-common/src/main/java/com/ruoyi/common/model/WhereModel.java b/ruoyi-common/src/main/java/com/ruoyi/common/model/WhereModel.java new file mode 100644 index 0000000000000000000000000000000000000000..406b1b53aebd87259d39892bc0ad2e0bb4da6dc5 --- /dev/null +++ b/ruoyi-common/src/main/java/com/ruoyi/common/model/WhereModel.java @@ -0,0 +1,67 @@ +package com.ruoyi.common.model; + +import com.ruoyi.common.utils.StringUtils; + +public class WhereModel { + private String whereColumn; + private String table; + private Object value; + private String connectType; + private String method; + + public static final String METHOD_EQUAS = "="; + public static final String METHOD_LIKE = "like"; + public static final String CONNECT_AND = "AND"; + public static final String CONNECT_OR = "OR"; + + public String getWhereColumn() { + return whereColumn; + } + + public void setWhereColumn(String whereColumn) { + this.whereColumn = whereColumn; + } + + public String getTable() { + return table; + } + + public void setTable(String table) { + this.table = table; + } + + public Object getValue() { + return value; + } + + public void setValue(Object value) { + this.value = value; + } + + public String getFullTableColumn() { + if (StringUtils.isEmpty(this.table)) { + return this.whereColumn; + } + return this.table + "." + this.whereColumn; + } + + public String getConnectType() { + return connectType; + } + + public void setConnectType(String connectType) { + this.connectType = connectType; + } + + public String getMethod() { + return method; + } + + public void setMethod(String method) { + this.method = method; + } + + public String getSqlString() { + return String.format(" %s %s %s %s ", this.getConnectType(), this.getFullTableColumn(), this.method, this.value); + } +} \ No newline at end of file diff --git a/ruoyi-common/src/main/java/com/ruoyi/common/utils/DataSecurityUtil.java b/ruoyi-common/src/main/java/com/ruoyi/common/utils/DataSecurityUtil.java new file mode 100644 index 0000000000000000000000000000000000000000..5dd6e23e6857868000ac68f545b378aa7456b1ec --- /dev/null +++ b/ruoyi-common/src/main/java/com/ruoyi/common/utils/DataSecurityUtil.java @@ -0,0 +1,14 @@ +package com.ruoyi.common.utils; + +import com.ruoyi.common.context.dataSecurity.DataSecurityContextHolder; + +public class DataSecurityUtil { + + public static void closeDataSecurity() { + DataSecurityContextHolder.clearCache(); + } + + public static void startDataSecurity() { + DataSecurityContextHolder.startDataSecurity(); + } +} diff --git a/ruoyi-common/src/main/java/com/ruoyi/common/utils/sql/SqlUtil.java b/ruoyi-common/src/main/java/com/ruoyi/common/utils/sql/SqlUtil.java index 2650fb7b2739be941f8c142d3fbdd37997bc404f..e66d1590722c9dca7d732949ceb3c1d866a5bf86 100644 --- a/ruoyi-common/src/main/java/com/ruoyi/common/utils/sql/SqlUtil.java +++ b/ruoyi-common/src/main/java/com/ruoyi/common/utils/sql/SqlUtil.java @@ -1,32 +1,37 @@ package com.ruoyi.common.utils.sql; +import java.io.StringReader; + import com.ruoyi.common.exception.UtilException; import com.ruoyi.common.utils.StringUtils; +import net.sf.jsqlparser.JSQLParserException; +import net.sf.jsqlparser.parser.CCJSqlParserManager; +import net.sf.jsqlparser.statement.Statement; + /** * sql操作工具类 * * @author ruoyi */ -public class SqlUtil -{ +public class SqlUtil { /** * 定义常用的 sql关键字 */ - public static String SQL_REGEX = "and |extractvalue|updatexml|sleep|exec |insert |select |delete |update |drop |count |chr |mid |master |truncate |char |declare |or |union |like |+|/*|user()"; + public static String SQL_REGEX = "and |extractvalue|updatexml|exec |insert |select |delete |update |drop |count |chr |mid |master |truncate |char |declare |or |+|user()"; /** * 仅支持字母、数字、下划线、空格、逗号、小数点(支持多个字段排序) */ public static String SQL_PATTERN = "[a-zA-Z0-9_\\ \\,\\.]+"; + private static final CCJSqlParserManager parserManager = new CCJSqlParserManager(); + /** * 检查字符,防止注入绕过 */ - public static String escapeOrderBySql(String value) - { - if (StringUtils.isNotEmpty(value) && !isValidOrderBySql(value)) - { + public static String escapeOrderBySql(String value) { + if (StringUtils.isNotEmpty(value) && !isValidOrderBySql(value)) { throw new UtilException("参数不符合规范,不能进行查询"); } return value; @@ -35,27 +40,26 @@ public class SqlUtil /** * 验证 order by 语法是否符合规范 */ - public static boolean isValidOrderBySql(String value) - { + public static boolean isValidOrderBySql(String value) { return value.matches(SQL_PATTERN); } /** * SQL关键字检查 */ - public static void filterKeyword(String value) - { - if (StringUtils.isEmpty(value)) - { + public static void filterKeyword(String value) { + if (StringUtils.isEmpty(value)) { return; } String[] sqlKeywords = StringUtils.split(SQL_REGEX, "\\|"); - for (String sqlKeyword : sqlKeywords) - { - if (StringUtils.indexOfIgnoreCase(value, sqlKeyword) > -1) - { + for (String sqlKeyword : sqlKeywords) { + if (StringUtils.indexOfIgnoreCase(value, sqlKeyword) > -1) { throw new UtilException("参数存在SQL注入风险"); } } } + + public static Statement parseSql(String sql) throws JSQLParserException { + return parserManager.parse(new StringReader(sql)); + } } diff --git a/ruoyi-framework/src/main/java/com/ruoyi/framework/aspectj/DataScopeAspect.java b/ruoyi-framework/src/main/java/com/ruoyi/framework/aspectj/DataScopeAspect.java index 4866373a9c01c3431038c2740e39444df35a8376..1bc2f69349a9d1fa10e5df32b23b264415cad45b 100644 --- a/ruoyi-framework/src/main/java/com/ruoyi/framework/aspectj/DataScopeAspect.java +++ b/ruoyi-framework/src/main/java/com/ruoyi/framework/aspectj/DataScopeAspect.java @@ -2,12 +2,10 @@ package com.ruoyi.framework.aspectj; import java.util.ArrayList; import java.util.List; - import org.aspectj.lang.JoinPoint; import org.aspectj.lang.annotation.Aspect; import org.aspectj.lang.annotation.Before; import org.springframework.stereotype.Component; - import com.ruoyi.common.annotation.DataScope; import com.ruoyi.common.core.domain.BaseEntity; import com.ruoyi.common.core.domain.entity.SysRole; @@ -25,7 +23,8 @@ import com.ruoyi.framework.security.context.PermissionContextHolder; */ @Aspect @Component -public class DataScopeAspect { +public class DataScopeAspect +{ /** * 全部数据权限 */ @@ -57,20 +56,23 @@ public class DataScopeAspect { public static final String DATA_SCOPE = "dataScope"; @Before("@annotation(controllerDataScope)") - public void doBefore(JoinPoint point, DataScope controllerDataScope) throws Throwable { + public void doBefore(JoinPoint point, DataScope controllerDataScope) throws Throwable + { clearDataScope(point); handleDataScope(point, controllerDataScope); } - protected void handleDataScope(final JoinPoint joinPoint, DataScope controllerDataScope) { + protected void handleDataScope(final JoinPoint joinPoint, DataScope controllerDataScope) + { // 获取当前的用户 LoginUser loginUser = SecurityUtils.getLoginUser(); - if (StringUtils.isNotNull(loginUser)) { + if (StringUtils.isNotNull(loginUser)) + { SysUser currentUser = loginUser.getUser(); // 如果是超级管理员,则不过滤数据 - if (StringUtils.isNotNull(currentUser) && !currentUser.isAdmin()) { - String permission = StringUtils.defaultIfEmpty(controllerDataScope.permission(), - PermissionContextHolder.getContext()); + if (StringUtils.isNotNull(currentUser) && !currentUser.isAdmin()) + { + String permission = StringUtils.defaultIfEmpty(controllerDataScope.permission(), PermissionContextHolder.getContext()); dataScopeFilter(joinPoint, currentUser, controllerDataScope.deptAlias(), controllerDataScope.userAlias(), permission); } @@ -80,57 +82,59 @@ public class DataScopeAspect { /** * 数据范围过滤 * - * @param joinPoint 切点 - * @param user 用户 - * @param deptAlias 部门别名 - * @param userAlias 用户别名 + * @param joinPoint 切点 + * @param user 用户 + * @param deptAlias 部门别名 + * @param userAlias 用户别名 * @param permission 权限字符 */ - public static void dataScopeFilter(JoinPoint joinPoint, SysUser user, String deptAlias, String userAlias, - String permission) { + public static void dataScopeFilter(JoinPoint joinPoint, SysUser user, String deptAlias, String userAlias, String permission) + { StringBuilder sqlString = new StringBuilder(); List conditions = new ArrayList(); - List scopeCustomIds = new ArrayList(); - user.getRoles().forEach(role -> { - if (DATA_SCOPE_CUSTOM.equals(role.getDataScope()) - && StringUtils.containsAny(role.getPermissions(), Convert.toStrArray(permission))) { - scopeCustomIds.add(Convert.toStr(role.getRoleId())); - } - }); - for (SysRole role : user.getRoles()) { + for (SysRole role : user.getRoles()) + { String dataScope = role.getDataScope(); - if (conditions.contains(dataScope)) { + if (!DATA_SCOPE_CUSTOM.equals(dataScope) && conditions.contains(dataScope)) + { continue; } - if (!StringUtils.containsAny(role.getPermissions(), Convert.toStrArray(permission))) { + if (StringUtils.isNotEmpty(permission) && StringUtils.isNotEmpty(role.getPermissions()) + && !StringUtils.containsAny(role.getPermissions(), Convert.toStrArray(permission))) + { continue; } - if (DATA_SCOPE_ALL.equals(dataScope)) { + if (DATA_SCOPE_ALL.equals(dataScope)) + { sqlString = new StringBuilder(); conditions.add(dataScope); break; - } else if (DATA_SCOPE_CUSTOM.equals(dataScope)) { - if (scopeCustomIds.size() > 1) { - // 多个自定数据权限使用in查询,避免多次拼接。 - sqlString.append(StringUtils.format( - " OR {}.dept_id IN ( SELECT dept_id FROM sys_role_dept WHERE role_id in ({}) ) ", deptAlias, - String.join(",", scopeCustomIds))); - } else { - sqlString.append(StringUtils.format( - " OR {}.dept_id IN ( SELECT dept_id FROM sys_role_dept WHERE role_id = {} ) ", deptAlias, - role.getRoleId())); - } - } else if (DATA_SCOPE_DEPT.equals(dataScope)) { + } + else if (DATA_SCOPE_CUSTOM.equals(dataScope)) + { + sqlString.append(StringUtils.format( + " OR {}.dept_id IN ( SELECT dept_id FROM sys_role_dept WHERE role_id = {} ) ", deptAlias, + role.getRoleId())); + } + else if (DATA_SCOPE_DEPT.equals(dataScope)) + { sqlString.append(StringUtils.format(" OR {}.dept_id = {} ", deptAlias, user.getDeptId())); - } else if (DATA_SCOPE_DEPT_AND_CHILD.equals(dataScope)) { + } + else if (DATA_SCOPE_DEPT_AND_CHILD.equals(dataScope)) + { sqlString.append(StringUtils.format( " OR {}.dept_id IN ( SELECT dept_id FROM sys_dept WHERE dept_id = {} or find_in_set( {} , ancestors ) )", deptAlias, user.getDeptId(), user.getDeptId())); - } else if (DATA_SCOPE_SELF.equals(dataScope)) { - if (StringUtils.isNotBlank(userAlias)) { + } + else if (DATA_SCOPE_SELF.equals(dataScope)) + { + if (StringUtils.isNotBlank(userAlias)) + { sqlString.append(StringUtils.format(" OR {}.user_id = {} ", userAlias, user.getUserId())); - } else { + } + else + { // 数据权限为仅本人且没有userAlias别名不查询任何数据 sqlString.append(StringUtils.format(" OR {}.dept_id = 0 ", deptAlias)); } @@ -139,13 +143,16 @@ public class DataScopeAspect { } // 多角色情况下,所有角色都不包含传递过来的权限字符,这个时候sqlString也会为空,所以要限制一下,不查询任何数据 - if (StringUtils.isEmpty(conditions)) { + if (StringUtils.isEmpty(conditions)) + { sqlString.append(StringUtils.format(" OR {}.dept_id = 0 ", deptAlias)); } - if (StringUtils.isNotBlank(sqlString.toString())) { + if (StringUtils.isNotBlank(sqlString.toString())) + { Object params = joinPoint.getArgs()[0]; - if (StringUtils.isNotNull(params) && params instanceof BaseEntity) { + if (StringUtils.isNotNull(params) && params instanceof BaseEntity) + { BaseEntity baseEntity = (BaseEntity) params; baseEntity.getParams().put(DATA_SCOPE, " AND (" + sqlString.substring(4) + ")"); } @@ -155,9 +162,11 @@ public class DataScopeAspect { /** * 拼接权限sql前先清空params.dataScope参数防止注入 */ - private void clearDataScope(final JoinPoint joinPoint) { + private void clearDataScope(final JoinPoint joinPoint) + { Object params = joinPoint.getArgs()[0]; - if (StringUtils.isNotNull(params) && params instanceof BaseEntity) { + if (StringUtils.isNotNull(params) && params instanceof BaseEntity) + { BaseEntity baseEntity = (BaseEntity) params; baseEntity.getParams().put(DATA_SCOPE, ""); } diff --git a/ruoyi-framework/src/main/java/com/ruoyi/framework/aspectj/DataSecurityAspect.java b/ruoyi-framework/src/main/java/com/ruoyi/framework/aspectj/DataSecurityAspect.java new file mode 100644 index 0000000000000000000000000000000000000000..da4f3887314d61a207e9ba4860555ce2adf89197 --- /dev/null +++ b/ruoyi-framework/src/main/java/com/ruoyi/framework/aspectj/DataSecurityAspect.java @@ -0,0 +1,87 @@ +package com.ruoyi.framework.aspectj; + +import java.util.List; + +import org.aspectj.lang.JoinPoint; +import org.aspectj.lang.annotation.After; +import org.aspectj.lang.annotation.Aspect; +import org.aspectj.lang.annotation.Before; +import org.aspectj.lang.annotation.Pointcut; +import org.springframework.stereotype.Component; + +import com.ruoyi.common.annotation.sql.DataSecurity; +import com.ruoyi.common.context.dataSecurity.DataSecurityContextHolder; +import com.ruoyi.common.enums.DataSecurityStrategy; +import com.ruoyi.common.model.JoinTableModel; +import com.ruoyi.common.model.WhereModel; +import com.ruoyi.common.utils.SecurityUtils; +import com.ruoyi.common.utils.StringUtils; + +import ch.qos.logback.core.util.StringUtil; + +@Aspect +@Component +public class DataSecurityAspect { + + @Before(value = "@annotation(dataSecurity)") + public void doBefore(final JoinPoint point, DataSecurity dataSecurity) throws Throwable { + DataSecurityContextHolder.startDataSecurity(); + switch (dataSecurity.strategy()) { + case CREEATE_BY: + WhereModel createByModel = new WhereModel(); + createByModel.setTable(dataSecurity.table()); + createByModel.setValue("\"" + SecurityUtils.getUsername() + "\""); + createByModel.setWhereColumn("create_by"); + createByModel.setMethod(WhereModel.METHOD_EQUAS); + createByModel.setConnectType(WhereModel.CONNECT_AND); + DataSecurityContextHolder.addWhereParam(createByModel); + break; + case USER_ID: + WhereModel userIdModel = new WhereModel(); + userIdModel.setTable(dataSecurity.table()); + userIdModel.setTable("user_id"); + userIdModel.setValue(SecurityUtils.getUserId()); + userIdModel.setConnectType(WhereModel.CONNECT_AND); + userIdModel.setMethod(WhereModel.METHOD_EQUAS); + DataSecurityContextHolder.addWhereParam(userIdModel); + break; + case JOINTABLE_CREATE_BY: + JoinTableModel createByTableModel = new JoinTableModel(); + createByTableModel.setFromTable(dataSecurity.table()); + createByTableModel.setFromTableAlise(dataSecurity.table()); + createByTableModel.setJoinTable("sys_user"); + if (!StringUtils.isEmpty(dataSecurity.joinTableAlise())) { + createByTableModel.setJoinTableAlise(dataSecurity.joinTableAlise()); + } + + createByTableModel.setFromTableColumn("create_by"); + createByTableModel.setJoinTableColumn("user_name"); + DataSecurityContextHolder.addJoinTable(createByTableModel); + break; + case JOINTABLE_USER_ID: + JoinTableModel userIdTableModel = new JoinTableModel(); + userIdTableModel.setFromTable(dataSecurity.table()); + userIdTableModel.setFromTableAlise(dataSecurity.table()); + userIdTableModel.setJoinTable("sys_user"); + if (!StringUtils.isEmpty(dataSecurity.joinTableAlise())) { + userIdTableModel.setJoinTableAlise(dataSecurity.joinTableAlise()); + } + + userIdTableModel.setFromTableColumn("user_id"); + userIdTableModel.setJoinTableColumn("user_id"); + DataSecurityContextHolder.addJoinTable(userIdTableModel); + break; + + default: + break; + } + + } + + @After(value = " @annotation(dataSecurity)") + public void doAfter(final JoinPoint point, DataSecurity dataSecurity) { + DataSecurityContextHolder.clearCache(); + + } + +} diff --git a/ruoyi-framework/src/main/java/com/ruoyi/framework/aspectj/DataSourceAspect.java b/ruoyi-framework/src/main/java/com/ruoyi/framework/aspectj/DataSourceAspect.java index f551e67fbcc17a7082c8e6d522956330a4ea6bef..8c2c9f4385732249ebd98b9a793352358d2dd7db 100644 --- a/ruoyi-framework/src/main/java/com/ruoyi/framework/aspectj/DataSourceAspect.java +++ b/ruoyi-framework/src/main/java/com/ruoyi/framework/aspectj/DataSourceAspect.java @@ -1,7 +1,6 @@ package com.ruoyi.framework.aspectj; import java.util.Objects; - import org.aspectj.lang.ProceedingJoinPoint; import org.aspectj.lang.annotation.Around; import org.aspectj.lang.annotation.Aspect; @@ -12,7 +11,6 @@ import org.slf4j.LoggerFactory; import org.springframework.core.annotation.AnnotationUtils; import org.springframework.core.annotation.Order; import org.springframework.stereotype.Component; - import com.ruoyi.common.annotation.DataSource; import com.ruoyi.common.utils.StringUtils; import com.ruoyi.framework.datasource.DynamicDataSourceContextHolder; @@ -25,31 +23,33 @@ import com.ruoyi.framework.datasource.DynamicDataSourceContextHolder; @Aspect @Order(1) @Component -public class DataSourceAspect { +public class DataSourceAspect +{ protected Logger logger = LoggerFactory.getLogger(getClass()); @Pointcut("@annotation(com.ruoyi.common.annotation.DataSource)" + "|| @within(com.ruoyi.common.annotation.DataSource)") - public void dsPointCut() { + public void dsPointCut() + { } @Around("dsPointCut()") - public Object around(ProceedingJoinPoint point) throws Throwable { + public Object around(ProceedingJoinPoint point) throws Throwable + { DataSource dataSource = getDataSource(point); - if (StringUtils.isNotNull(dataSource)) { - if ("".equals(dataSource.name())) { - DynamicDataSourceContextHolder.setDataSourceType(dataSource.value().name()); - } else { - DynamicDataSourceContextHolder.setDataSourceType(dataSource.name()); - } - + if (StringUtils.isNotNull(dataSource)) + { + DynamicDataSourceContextHolder.setDataSourceType(dataSource.value().name()); } - try { + try + { return point.proceed(); - } finally { + } + finally + { // 销毁数据源 在执行方法之后 DynamicDataSourceContextHolder.clearDataSourceType(); } @@ -58,10 +58,12 @@ public class DataSourceAspect { /** * 获取需要切换的数据源 */ - public DataSource getDataSource(ProceedingJoinPoint point) { + public DataSource getDataSource(ProceedingJoinPoint point) + { MethodSignature signature = (MethodSignature) point.getSignature(); DataSource dataSource = AnnotationUtils.findAnnotation(signature.getMethod(), DataSource.class); - if (Objects.nonNull(dataSource)) { + if (Objects.nonNull(dataSource)) + { return dataSource; } diff --git a/ruoyi-framework/src/main/java/com/ruoyi/framework/interceptor/mybatis/MybatisInterceptor.java b/ruoyi-framework/src/main/java/com/ruoyi/framework/interceptor/mybatis/MybatisInterceptor.java new file mode 100644 index 0000000000000000000000000000000000000000..caf4af5e5f993c4fd8f58d0ce60d30d2433c9c95 --- /dev/null +++ b/ruoyi-framework/src/main/java/com/ruoyi/framework/interceptor/mybatis/MybatisInterceptor.java @@ -0,0 +1,129 @@ +package com.ruoyi.framework.interceptor.mybatis; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +import org.apache.ibatis.cache.CacheKey; +import org.apache.ibatis.executor.Executor; +import org.apache.ibatis.mapping.BoundSql; +import org.apache.ibatis.mapping.MappedStatement; +import org.apache.ibatis.plugin.Interceptor; +import org.apache.ibatis.plugin.Intercepts; +import org.apache.ibatis.plugin.Invocation; +import org.apache.ibatis.plugin.Signature; +import org.apache.ibatis.session.ResultHandler; +import org.apache.ibatis.session.RowBounds; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Component; + +import com.ruoyi.common.annotation.sql.MybatisHandlerOrder; +import com.ruoyi.common.handler.sql.MybatisAfterHandler; +import com.ruoyi.common.handler.sql.MybatisPreHandler; + +import jakarta.annotation.PostConstruct; + +@Component +@Intercepts({ + @Signature(type = Executor.class, method = "query", args = { MappedStatement.class, Object.class, + RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class }), + @Signature(type = Executor.class, method = "query", args = { + MappedStatement.class, Object.class, RowBounds.class, + ResultHandler.class }) + +}) +public class MybatisInterceptor implements Interceptor { + + @Autowired + private List preHandlerBeans; + + @Autowired + private List afterHandlerBeans; + + private static List preHandlersChain; + + private static List afterHandlersChain; + + @PostConstruct + public void init() { + List sortedPreHandlers = preHandlerBeans.stream().sorted((item1, item2) -> { + int a; + int b; + MybatisHandlerOrder ann1 = item1.getClass().getAnnotation(MybatisHandlerOrder.class); + MybatisHandlerOrder ann2 = item2.getClass().getAnnotation(MybatisHandlerOrder.class); + if (ann1 == null) { + a = 0; + } else { + a = ann1.value(); + } + if (ann2 == null) { + b = 0; + } else { + b = ann2.value(); + } + return a - b; + }).collect(Collectors.toList()); + preHandlersChain = sortedPreHandlers; + + List sortedAfterHandlers = afterHandlerBeans.stream().sorted((item1, item2) -> { + int a; + int b; + MybatisHandlerOrder ann1 = item1.getClass().getAnnotation(MybatisHandlerOrder.class); + MybatisHandlerOrder ann2 = item2.getClass().getAnnotation(MybatisHandlerOrder.class); + if (ann1 == null) { + a = 0; + } else { + a = ann1.value(); + } + if (ann2 == null) { + b = 0; + } else { + b = ann2.value(); + } + return a - b; + }).collect(Collectors.toList()); + afterHandlersChain = sortedAfterHandlers; + } + + @Override + public Object intercept(Invocation invocation) throws Throwable { + Executor targetExecutor = (Executor) invocation.getTarget(); + Object[] args = invocation.getArgs(); + if (args.length < 6) { + if (preHandlersChain != null && preHandlersChain.size() > 0) { + MappedStatement ms = (MappedStatement) args[0]; + Object parameterObject = args[1]; + RowBounds rowBounds = (RowBounds) args[2]; + Executor executor = (Executor) invocation.getTarget(); + BoundSql boundSql = ms.getBoundSql(parameterObject); + // 可以对参数做各种处理 + CacheKey cacheKey = executor.createCacheKey(ms, parameterObject, rowBounds, boundSql); + for (MybatisPreHandler item : preHandlersChain) { + item.preHandle(targetExecutor, ms, args[1], (RowBounds) args[2], + (ResultHandler) args[3], cacheKey, boundSql); + } + } + Object result = invocation.proceed(); + if (afterHandlersChain != null && afterHandlersChain.size() > 0) { + for (MybatisAfterHandler item : afterHandlersChain) { + item.handleObject(result); + } + } + return result; + } + if (preHandlersChain != null && preHandlersChain.size() > 0) { + for (MybatisPreHandler item : preHandlersChain) { + item.preHandle(targetExecutor, (MappedStatement) args[0], args[1], (RowBounds) args[2], + (ResultHandler) args[3], (CacheKey) args[4], (BoundSql) args[5]); + } + } + Object result = invocation.proceed(); + if (afterHandlersChain != null && afterHandlersChain.size() > 0) { + for (MybatisAfterHandler item : afterHandlersChain) { + result = item.handleObject(result); + } + } + return result; + } + +}