背景:隨著業務的發展,我們同一套業務系統需支持提供給多個客戶(不同的企業用戶)使用,所以需確保在多用戶環境下,各用戶間數據的隔離。但目前系統在早期設計的時候沒有考慮到多租戶的情況,業務數據沒有做到充分隔離(有些表做了字段區分,有些沒有)。
目前數據訪問層用的是MyBatis框架,sql語句散布在xml里,dao注解里,量非常大。另外,租戶字段(companyId)定義也不是所有的業務實體類都有。
基于現狀,一個個修改sql,這樣工作量太大了,所以排除掉一個個修改sql的方案。只能考慮怎樣可以統一修改sql。而租戶字段(companyId)的傳遞也需要有統一處理的地方。
一、業務表添加數據隔字段
我們先給沒有租戶字段(companyId)的表加上字段。然后考慮怎樣給字段統一添加值的改造。因為業務系統目前是使用Mybatis做持久化,Mybatis有攔截器的功能,是否可以通過自定義Mybatis攔截器攔截下所有的 sql 語句,然后對其進行動態修改,自動添加company_id 字段及其字段值,實現數據隔離呢?答案是肯定的。
二、添加Mybatis攔截器
先看下Mybatis的核心對象:
Mybatis核心對象 |
解釋 |
SqlSession |
作為MyBatis工作的主要頂層API,表示和數據庫交互的會話,完成必要數據庫增刪改查功能。 |
Executor |
MyBatis執行器,是MyBatis 調度的核心,負責SQL語句的生成和查詢緩存的維護。 |
StatementHandler |
封裝了JDBC Statement操作,負責對JDBC statement 的操作,如設置參數、將Statement結果集轉換成List集合。 |
ParameterHandler |
負責對用戶傳遞的參數轉換成JDBC Statement 所需要的參數。 |
ResultSetHandler |
負責將JDBC返回的ResultSet結果集對象轉換成List類型的集合。 |
TypeHandler |
負責JAVA數據類型和jdbc數據類型之間的映射和轉換。 |
MAppedStatement |
MappedStatement維護了一條mapper.xml文件里面 select 、update、delete、insert節點的封裝。 |
SqlSource |
負責根據用戶傳遞的parameterObject,動態地生成SQL語句,將信息封裝到BoundSql對象中。 |
BoundSql |
表示動態生成的SQL語句以及相應的參數信息。 |
Configuration |
MyBatis所有的配置信息都維持在Configuration對象。 |
Mybatis攔截器可以攔截Executor、ParameterHandler、StatementHandler、ResultSetHandler四個對象里面的方法。Executor是Mybatis的核心接口。Mybatis中所有的Mapper語句的執行都是通過Executor進行的。其中增刪改語句是通過Executor接口的update方法,查詢語句是通過query方法。所以我們可以攔截Executor,攔載所有的select 、insert、update、delete語句進行改造,添加company_id字段及字段值。
創建一個自定義的攔截器:
/**
* Mybatis - 通用攔截器。用于攔截sql并自動補充公共字段。包括query、insert、update、delete語句
*/
@Slf4j
@Intercepts(
{
@Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
@Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class}),
@Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class})
}
)
public class AutoFillParamInterceptor implements Interceptor {
private static final String LAST_INSERT_ID_SQL = "LAST_INSERT_ID()";
private static final String COMPANY_ID = "company_id";
/**
* 攔截主要的邏輯
* @param invocation
* @return
* @throws Throwable
*/
@Override
public Object intercept(Invocation invocation) throws Throwable {
final Object[] args = invocation.getArgs();
final MappedStatement ms = (MappedStatement) args[0];
final Object paramObj = args[1];
// 1.通過注解判斷是否需要處理此SQL
String namespace = ms.getId();
String className = namespace.substring(0, namespace.lastIndexOf("."));
//selectByExample
String methodName = StringUtils.substringAfterLast(namespace, ".");
Class<?> classType = Class.forName(className);
if (classType.isAnnotationPresent(IgnoreAutoFill.class)) {
//注解在類上
String userType = classType.getAnnotation(IgnoreAutoFill.class).userType();
if (StringUtils.isNotBlank(userType)) {
//ignore特定的用戶類型,其他均攔截
if (userType.equals(getCurrentUserType())) {
return invocation.proceed();
}
} else {
return invocation.proceed();
}
} else {
//注解在方法上
for (Method method : classType.getMethods()) {
if (!methodName.equals(method.getName())) {
continue;
} else {
if (method.isAnnotationPresent(IgnoreAutoFill.class)) {
String userType = method.getAnnotation(IgnoreAutoFill.class).userType();
if (StringUtils.isNotBlank(userType)) {
//ignore特定的用戶類型,其他均攔截
if (userType.equals(getCurrentUserType())) {
return invocation.proceed();
}
} else {
return invocation.proceed();
}
}
break;
}
}
}
// 2.獲取SQL語句
BoundSql boundSql = ms.getBoundSql(paramObj);
// 原始sql
String originalSql = boundSql.getSql();
log.debug("originalSql:{}", originalSql);
// 3.根據語句類型改造SQL語句
switch (ms.getSqlCommandType()) {
case INSERT: {
originalSql = convertInsertSQL(originalSql);
args[0] = newMappedStatement(ms, boundSql, originalSql, paramObj);
break;
}
case UPDATE:
case DELETE: {
originalSql = SQLUtils.addCondition(originalSql, COMPANY_ID + "='" + getCompanyId() +"'", null);
args[0] = newMappedStatement(ms, boundSql, originalSql, paramObj);
break;
}
case SELECT: {
if (!StringUtils.containsIgnoreCase(originalSql, LAST_INSERT_ID_SQL)) {
//where 條件拼接 companyId
MySQLStatementParser parser = new MySqlStatementParser(originalSql);
SQLStatement statement = parser.parseStatement();
SQLSelectStatement selectStatement = (SQLSelectStatement) statement;
SQLSelect sqlSelect = selectStatement.getSelect();
SQLSelectQuery query = sqlSelect.getQuery();
addSelectCondition(query, COMPANY_ID + "='" + getCompanyId() + "'");
originalSql = SQLUtils.toSQLString(selectStatement, JdbcConstants.MYSQL);
// 將新生成的MappedStatement對象替換到參數列表中
args[0] = newMappedStatement(ms, boundSql, originalSql, paramObj);
}
break;
}
}
log.debug("modifiedSql:{}", originalSql);
// 4.應用修改后的SQL語句
return invocation.proceed();
}
private void addSelectCondition(SQLSelectQuery query, String condition){
if (query instanceof SQLUnionQuery) {
SQLUnionQuery sqlUnionQuery = (SQLUnionQuery) query;
addSelectCondition(sqlUnionQuery.getLeft(), condition);
addSelectCondition(sqlUnionQuery.getRight(), condition);
} else if (query instanceof SQLSelectQueryBlock) {
SQLSelectQueryBlock selectQueryBlock = (SQLSelectQueryBlock) query;
SQLTableSource tableSource = selectQueryBlock.getFrom();
String conditionTmp = condition;
String alias = getLeftAlias(tableSource);
if (StringUtils.isNotBlank(alias)) {
//拼接別名
conditionTmp = alias + "." + condition;
}
SQLExpr conditionExpr = SQLUtils.toMySqlExpr(conditionTmp);
selectQueryBlock.addCondition(conditionExpr);
}
}
private String getLeftAlias(SQLTableSource tableSource) {
if (tableSource != null) {
if (tableSource instanceof SQLExprTableSource) {
if (StringUtils.isNotBlank(tableSource.getAlias())) {
return tableSource.getAlias();
}
} else if (tableSource instanceof SQLJoinTableSource) {
SQLJoinTableSource join = (SQLJoinTableSource) tableSource;
return getLeftAlias(join.getLeft());
}
}
return null;
}
/**
* 用于封裝目標對象的,通過該方法我們可以返回目標對象本身,也可以返回一個它的代理
* @param target
* @return
*/
@Override
public Object plugin(Object target) {
//只攔截Executor對象,減少目標被代理的次數
if (target instanceof Executor) {
return Plugin.wrap(target, this);
}
return target;
}
/**
* 注冊當前攔截器的時候可以設置一些屬性
*/
@Override
public void setProperties(Properties properties) {
}
private String convertInsertSQL(String originalSql) {
MySqlStatementParser parser = new MySqlStatementParser(originalSql);
SQLStatement statement = parser.parseStatement();
MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
statement.accept(visitor);
MySqlInsertStatement myStatement = (MySqlInsertStatement) statement;
String tableName = myStatement.getTableName().getSimpleName();
List<SQLExpr> columns = myStatement.getColumns();
List<SQLInsertStatement.ValuesClause> vcl = myStatement.getValuesList();
if (columns == null || columns.size() <= 0 || myStatement.getQuery() != null) {
return originalSql;
}
if (!visitor.containsColumn(tableName, COMPANY_ID)) {
SQLExpr columnExpr = SQLUtils.toMySqlExpr(COMPANY_ID);
columns.add(columnExpr);
SQLExpr valuesExpr = SQLUtils.toMySqlExpr("'" + getCompanyId() + "'");
vcl.stream().forEach(v -> v.addValue(valuesExpr));
}
return SQLUtils.toSQLString(myStatement, JdbcConstants.MYSQL);
}
private MappedStatement newMappedStatement(MappedStatement ms, BoundSql boundSql,
String sql, Object parameter){
BoundSql newBoundSql = new BoundSql(ms.getConfiguration(),sql, new ArrayList(boundSql.getParameterMappings()), parameter);
for (ParameterMapping mapping : boundSql.getParameterMappings()) {
String prop = mapping.getProperty();
if (boundSql.hasAdditionalParameter(prop)) {
newBoundSql.setAdditionalParameter(prop, boundSql.getAdditionalParameter(prop));
}
}
return copyFromOriMappedStatement(ms, new WarpBoundSqlSqlSource(newBoundSql));
}
private MappedStatement copyFromOriMappedStatement(MappedStatement ms, SqlSource newSqlSource) {
MappedStatement.Builder builder = new MappedStatement.Builder(ms.getConfiguration(),ms.getId(),newSqlSource,ms.getSqlCommandType());
builder.cache(ms.getCache()).databaseId(ms.getDatabaseId())
.fetchSize(ms.getFetchSize())
.flushCacheRequired(ms.isFlushCacheRequired())
.keyColumn(StringUtils.join(ms.getKeyColumns(), ','))
.keyGenerator(ms.getKeyGenerator())
.keyProperty(StringUtils.join(ms.getKeyProperties(), ','))
.lang(ms.getLang()).parameterMap(ms.getParameterMap())
.resource(ms.getResource()).resultMaps(ms.getResultMaps())
.resultOrdered(ms.isResultOrdered())
.resultSets(StringUtils.join(ms.getResultSets(), ','))
.resultSetType(ms.getResultSetType()).statementType(ms.getStatementType())
.timeout(ms.getTimeout()).useCache(ms.isUseCache());
return builder.build();
}
static class WarpBoundSqlSqlSource implements SqlSource {
private final BoundSql boundSql;
public WarpBoundSqlSqlSource(BoundSql boundSql) {
this.boundSql = boundSql;
}
@Override
public BoundSql getBoundSql(Object parameterObject) {
return boundSql;
}
}
public String getCompanyId() {
//先從authenticationFacade取
String companyId = CompanyContext.getCompanyId();
if(StringUtils.isBlank(companyId)){
log.error("Can not get the companyId! {}", companyId);
throw new RuntimeException("Can not get the companyId! " + companyId);
}
return companyId;
}
public String getCurrentUserType() {
//authenticationFacade取
AuthenticationFacade authenticationFacade = ApplicationContextProvider.getBean(AuthenticationFacade.class);
Integer currentUserType = authenticationFacade.getCurrentUserType();
if (currentUserType == null) {
log.error("Can not get the currentUserType! {}", currentUserType);
throw new RuntimeException("Can not get the currentUserType! " + currentUserType);
}
UserTypeEnum userTypeEnum = UserTypeEnum.getByCode(currentUserType);
return userTypeEnum.getUserType();
}
}
雖然大部分sql都需要做條件過濾,但也有些特殊情況某些sql可能不需要過濾companyId條件,所以增加一個注解,如果不需要攔截的sql可以在Mapper類或方法上添加此注解,這樣可以兼容不需要攔截的方法。
添加 IgnoreAutoFill 注解:
/**
* 用于標注在不需要被攔截器處理的SQL上(Mapper類)
*/
@Target({ElementType.METHOD, ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
public @interface IgnoreAutoFill {
String userType() default "";
}
Mapper示例:
public interface PostRecordDOMapper {
long countByExample(PostRecordDOExample example);
int deleteByExample(PostRecordDOExample example);
int deleteByPrimaryKey(Long id);
int insert(PostRecordDO record);
int insertSelective(PostRecordDO record);
List<PostRecordDO> selectByExample(PostRecordDOExample example);
@IgnoreAutoFill
List<PostRecordDO> selectByExampleAllCompany(PostRecordDOExample example);
PostRecordDO selectByPrimaryKey(Long id);
int updateByExampleSelective(@Param("record") PostRecordDO record, @Param("example") PostRecordDOExample example);
int updateByExample(@Param("record") PostRecordDO record, @Param("example") PostRecordDOExample example);
int updateByPrimaryKeySelective(PostRecordDO record);
int updateByPrimaryKey(PostRecordDO record);
void batchInsert(@Param("items") List<PostRecordDO> items);
}
在攔截器中,我們使用阿里的druid做sql解析,修改sql。
加入 druid 依賴:
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>druid</artifactId>
<version>1.1.6</version>
</dependency>
攔截修改sql時,對于select、update、delete語句,我們直接添加company_id條件,對于insert語名,先判斷原sql的參數列表里有沒有company_id字段,如果有的話不作處理(說明原來就做了字段隔離),沒有才自動給它添加company_id字段及值。
至此,我們解決了統一修改sql的問題,那還有一個重要問題,填充的字段值從哪里取得呢?因為調用持久層Mapper類方法的入參并不一定帶有租戶字段(companyId)信息過來,有些方法甚至只會傳一個id的參數,像 deleteByPrimaryKey(Long id);selectByPrimaryKey(Long id);即使是傳對象參數,對象實體類也不一定有租戶字段(companyId)。所以如何傳遞租戶字段(companyId)是一個改造難點。
三、多租戶字段值的傳遞
考慮一翻,我們是否可以用 ThreadLocal 來存取呢?答案是肯定的。
要傳遞多租戶字段(companyId)值,得先取得companyId值。因為每一個系統用戶都有所屬的companyId,所以只要在用戶登錄系統的時候,從token中拿到用戶所屬的companyId,然后set進ThreadLocal。后續線程的處理都可以從ThreadLocal中取得companyId。這樣Mybatis攔截器也就隨時都可以取得companyId的值進行sql參數或者條件的拼接改造了。
多租戶上下文信息:
@Slf4j
public class CompanyContext implements AutoCloseable {
private static final TransmittableThreadLocal<String> COMPANY_ID_CTX = new TransmittableThreadLocal<>();
public CompanyContext(String companyId) {
COMPANY_ID_CTX.remove();
COMPANY_ID_CTX.set(companyId);
}
public static String getCompanyId(){
return COMPANY_ID_CTX.get();
}
@Override
public void close() throws Exception {
COMPANY_ID_CTX.remove();
}
public static void remove(){
COMPANY_ID_CTX.remove();
}
}
但是,系統的業務處理不可能只用一個線程從頭處理到結束,很多時候為了加快業務的處理,都是需要用到線程池的。
那么,問題又來了,不同線程間如何將這個companyId的ThreadLocal值傳遞下去呢?
這也是有解決方案的。
Transmittable ThreadLocal
Alibaba 有一個 Transmittable ThreadLocal 庫,提供了一個TransmittableThreadLocal,它是 ThreadLocal 的一個擴展,提供了將變量的值從一個線程傳遞到另一個線程的能力。當一個任務被提交到線程池時,TransmittableThreadLocal 變量的值被捕獲并傳遞給執行任務的工作線程。這確保了正確的值在工作線程中可用,即使它最初在不同的線程中設置。
使用Transmittable ThreadLocal 庫,需引入依賴:
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>transmittable-thread-local</artifactId>
<version>2.11.5</version>
</dependency>
使用的時候,調用一下TtlExecutors工具提供的getTtlExecutor靜態方法,傳入一個Executor,即可獲取一個支持 TTL (TransmittableThreadLocal)傳遞的 Executor 實例,此線程池就確保了上下文信息的正確傳遞,可放心使用了,如下所示:
@Bean(name = "exportDataExecutorPool")
public Executor exportDataExecutorPool() {
ThreadPoolTaskExecutor threadPoolTaskExecutor = new ThreadPoolTaskExecutor();
threadPoolTaskExecutor.setCorePoolSize(CPU_NUM);
threadPoolTaskExecutor.setMaxPoolSize(CPU_NUM * 2);
threadPoolTaskExecutor.setKeepAliveSeconds(60);
threadPoolTaskExecutor.setQueueCapacity(100);
threadPoolTaskExecutor.setThreadNamePrefix("ExportData Thread-");
threadPoolTaskExecutor.setRejectedExecutionHandler(new ThreadPoolExecutor.CallerRunsPolicy());
threadPoolTaskExecutor.initialize();
return TtlExecutors.getTtlExecutor(threadPoolTaskExecutor);
}
這樣就可以確保線程池的線程隨時可以都取到正確的companyId了。
至此,是不是就完成了改造了呢?
還沒有。
為什么呢?
如果是同一個JVM確實是沒問題了,如果不同的JVM呢?
一般較為復雜的系統都會按業務劃分成不同的模塊,同一個模塊也可能部署多個不同的實例,這些不同的模塊或不同的實例間的通信一般是通過遠程調用或者消息隊列進行數據傳遞。那么問題就來了,如何在不同的模塊或實例間傳遞這個companyId呢?
目前我們系統的遠程調用用的是RestTemplate,消息隊列用的Kafka。那就要考慮怎么把companyId統一傳遞出去了。
遠程調用 RestTemplate 的改造
- 對于RestTemplate,發送前我們可以通過ClientHttpRequestInterceptor攔截器,統一把companyId放進header。
@Slf4j
public class BearerTokenHeaderInterceptor implements ClientHttpRequestInterceptor {
public BearerTokenHeaderInterceptor() {
}
@Override
public ClientHttpResponse intercept(HttpRequest request, byte[] body,
ClientHttpRequestExecution execution) throws IOException {
//通過攔截器統一把companyId放到header
String companyId = CompanyContext.getCompanyId();
log.info("companyId={}", companyId);
if (!StringUtils.isEmpty(companyId)) {
request.getHeaders().set("companyId", companyId);
}
return execution.execute(request, body);
}
}
注意創建 RestTemplate 時需要把這個攔截器加進去:
@Bean
@LoadBalanced
public RestTemplate restTemplate(RestTemplateBuilder restTemplateBuilder) {
final RestTemplate restTemplate = restTemplateBuilder
.setConnectTimeout(Duration.ofMillis(getConnectTimeout()))
.setReadTimeout(Duration.ofMillis(getReadTimeout()))
.requestFactory(()->httpRequestFactory())
.build();
List<ClientHttpRequestInterceptor> interceptors = restTemplate.getInterceptors();
if (interceptors == null) {
interceptors = Collections.emptyList();
}
interceptors = new ArrayList<>(interceptors);
interceptors.removeIf(BearerTokenHeaderInterceptor.class::isInstance);
interceptors.add(new BearerTokenHeaderInterceptor());
restTemplate.setInterceptors(interceptors);
return restTemplate;
}
- 接收的地方也通過攔截器從header取得companyId并設置到本地變量:
@Slf4j
public class TokenParseAndLoginFilter extends OncePerRequestFilter {
@Override
protected void doFilterInternal(HttpServletRequest request,
HttpServletResponse response, FilterChain filterChain)
throws ServletException, IOException {
String accessToken = null;
String companyId = null;
try {
//從header取得并設置companyId本地變量
companyId = request.getHeader("companyId");
new CompanyContext(companyId);
filterChain.doFilter(request, response);
} catch (Exception e) {
log.error("request error:",e);
response.setContentType(MediaType.APPLICATION_JSON_UTF8_VALUE);
response.setStatus(500);
response.getWriter().write(e.getMessage());
response.getWriter().close();
}
}
}
消息隊列 kafka 的改造
- 發送消息的地方,我們統一把companyId放到kafka message header:
/**
* 發送消息
*/
public void sendMsg(String topic, Object value, Map<String, String> headers) {
RecordHeaders kafkaHeaders = new RecordHeaders();
headers.forEach((k,v)->{
RecordHeader recordHeader = new RecordHeader(k,v.getBytes());
kafkaHeaders.add(recordHeader);
});
RecordHeader recordHeader = new RecordHeader("companyId", CompanyContext.getCompanyId().getBytes());
kafkaHeaders.add(recordHeader);
//kafka默認分區
ProducerRecord<String, String> producerRecord = new ProducerRecord<String, String>(topic, null, null, JsonUtil.toJson(value), kafkaHeaders);
kafkaTemplate.send(producerRecord);
}
- 消息消費的地方,我們就可以從kafka message header中拿到companyId設置線程本地變量:
/**
* 獲取實例-手動處理ack
*/
@Bean
public KafkaListenerContainerFactory<ConcurrentMessageListenerContainer<String, String>> kafkaManualAckListenerContainerFactory() {
ConcurrentKafkaListenerContainerFactory<String, String> factory = new ConcurrentKafkaListenerContainerFactory<>();
factory.setConsumerFactory(consumerFactory());
factory.setConcurrency(concurrency);
factory.getContainerProperties().setPollTimeout(3000);
//RetryingAcknowledgingMessageListenerAdapter
factory.getContainerProperties().setAckMode(ContainerProperties.AckMode.MANUAL);
factory.setRetryTemplate(retryTemplate);
factory.setRecoveryCallback(recoveryCallback());
factory.setRecordFilterStrategy(consumerRecord -> {
String companyId = getHead(consumerRecord, "company_id");
// 設置companyId本地變量
new CompanyContext(companyId);
logger.info("Getting the company from kafka message header : {}", companyId);
if(needRequestId) {
String requestId = getHead(consumerRecord, KafkaHeadEnum.REQUEST_ID.getKey());
new RequestIdContext(requestId);
}
return false;
});
return factory;
}
至此,我們就完成了多租戶數據隔離的改造。
四、總結一下,改造的地方:
- 業務表沒有租戶字段(companyId)的,統一加上company_id字段。
- 利用Mybatis攔截器,攔載所有的select 、insert、update、delete語句進行改造SQL,自動添加company_id字段及字段值。
- 利用Transmittable ThreadLocal ,進行companyId值的傳遞。
- 對于http遠程調用的,通過攔截器,發送端統一添加companyId字段到header,接收端通過OncePerRequestFilter從header取得統一設到ThreadLocal。
- 對消息隊列(Kafka),發送端統一處理,添加companyId字段到message header,消費端通過RecordFilter從header取得統一設到ThreadLocal。