
This article shows how to implement a Vision Transformer (ViT) with JAX/Flax, and how to train it.

Vision Transformer

When implementing the Vision Transformer, it helps to keep the architecture figure from the paper in mind.

The paper describes the ViT forward pass as follows:

1. Extract patches from the input image and flatten each one into a vector.
2. Project the flattened patches to the dimension processed by the Transformer Encoder.
3. Prepend a learnable embedding (the [class] token) and add position embeddings.
4. Encode the sequence with the Transformer Encoder.
5. Use the [class] token output as input to an MLP head for classification.

Implementation details

Below we build each module with JAX/Flax.

1. From image to flattened image patches

The code below extracts patches from the input image. It is implemented as a convolution with kernel size patch_size × patch_size and stride patch_size × patch_size, so the patches do not overlap.

import jax.numpy as jnp
import flax.linen as nn

class Patches(nn.Module):
    patch_size: int
    embed_dim: int

    def setup(self):
        # Non-overlapping patches: kernel size == stride == patch_size
        self.conv = nn.Conv(
            features=self.embed_dim,
            kernel_size=(self.patch_size, self.patch_size),
            strides=(self.patch_size, self.patch_size),
            padding='VALID'
        )

    def __call__(self, images):
        patches = self.conv(images)  # (B, H/P, W/P, embed_dim)
        b, h, w, c = patches.shape
        patches = jnp.reshape(patches, (b, h * w, c))  # flatten the patch grid into a sequence
        return patches
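A quick shape check (a sketch, using the same hyperparameters as the training setup later in this post): with 32×32 CIFAR-10 images and patch_size=16, the convolution produces a 2×2 grid, i.e. 4 patch vectors.

import jax
import jax.numpy as jnp

patches = Patches(patch_size=16, embed_dim=192)
x = jnp.ones((5, 32, 32, 3))  # (batch, height, width, channels)
variables = patches.init(jax.random.PRNGKey(0), x)
out = patches.apply(variables, x)
print(out.shape)  # (5, 4, 192): the 2x2 patch grid flattened into a sequence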

2 and 3. Linear projection of the flattened patches / adding the [CLS] token and position embeddings

The Transformer Encoder uses the same hidden size, hidden_dim, in every layer. The patch vectors created above are projected to hidden_dim-dimensional vectors. As in BERT, a CLS token is prepended to the sequence, and a learnable position embedding is added to preserve positional information.

class PatchEncoder(nn.Module):
    hidden_dim: int

    @nn.compact
    def __call__(self, x):
        assert x.ndim == 3
        n, seq_len, _ = x.shape
        # Project to hidden_dim
        x = nn.Dense(self.hidden_dim)(x)
        # Prepend the learnable [CLS] token
        cls = self.param('cls_token', nn.initializers.zeros, (1, 1, self.hidden_dim))
        cls = jnp.tile(cls, (n, 1, 1))
        x = jnp.concatenate([cls, x], axis=1)
        # Add a learnable position embedding
        pos_embed = self.param(
            'position_embedding',
            nn.initializers.normal(stddev=0.02),  # from BERT
            (1, seq_len + 1, self.hidden_dim)
        )
        return x + pos_embed
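Continuing the shape check from above (a sketch): the encoder projects each patch vector to hidden_dim and prepends the [CLS] token, so the sequence grows by one.

encoder = PatchEncoder(hidden_dim=192)
variables = encoder.init(jax.random.PRNGKey(1), out)  # out is (5, 4, 192) from the Patches example
encoded = encoder.apply(variables, out)
print(encoded.shape)  # (5, 5, 192): [CLS] token + 4 patch tokens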

4. Transformer Encoder

As the architecture figure shows, the encoder consists of alternating multi-head self-attention (MSA) and MLP layers. Layer norm (LN) is applied before each MSA and MLP block, and a residual connection is added after each block.

class TransformerEncoder(nn.Module):
    embed_dim: int
    hidden_dim: int
    n_heads: int
    drop_p: float
    mlp_dim: int

    def setup(self):
        self.mha = MultiHeadSelfAttention(self.hidden_dim, self.n_heads, self.drop_p)
        self.mlp = MLP(self.mlp_dim, self.drop_p)
        # Note: a single LayerNorm instance is reused below, so both pre-norm
        # steps share their parameters (kept as in the original post).
        self.layer_norm = nn.LayerNorm(epsilon=1e-6)

    def __call__(self, inputs, train=True):
        # Attention block (pre-norm + residual)
        x = self.layer_norm(inputs)
        x = self.mha(x, train)
        x = inputs + x
        # MLP block (pre-norm + residual)
        y = self.layer_norm(x)
        y = self.mlp(y, train)

        return x + y

The MLP is a two-layer network with GELU activation. Dropout is applied after each Dense layer, as in the paper.

from typing import Optional

class MLP(nn.Module):
    mlp_dim: int
    drop_p: float
    out_dim: Optional[int] = None

    @nn.compact
    def __call__(self, inputs, train=True):
        actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim
        x = nn.Dense(features=self.mlp_dim)(inputs)
        x = nn.gelu(x)
        x = nn.Dropout(rate=self.drop_p, deterministic=not train)(x)
        x = nn.Dense(features=actual_out_dim)(x)
        x = nn.Dropout(rate=self.drop_p, deterministic=not train)(x)
        return x

Multi-head self-attention (MSA)

q, k and v are reshaped to [B, N, T, D]; the attention weights and outputs are computed per head, as in single-head attention, and the heads are then merged back to the original shape [B, T, C = N*D].

import math

class MultiHeadSelfAttention(nn.Module):
    hidden_dim: int
    n_heads: int
    drop_p: float

    def setup(self):
        self.q_net = nn.Dense(self.hidden_dim)
        self.k_net = nn.Dense(self.hidden_dim)
        self.v_net = nn.Dense(self.hidden_dim)

        self.proj_net = nn.Dense(self.hidden_dim)

        self.att_drop = nn.Dropout(self.drop_p)
        self.proj_drop = nn.Dropout(self.drop_p)

    def __call__(self, x, train=True):
        B, T, C = x.shape  # batch_size, seq_length, hidden_dim
        N, D = self.n_heads, C // self.n_heads  # num_heads, head_dim
        q = self.q_net(x).reshape(B, T, N, D).transpose(0, 2, 1, 3)  # (B, N, T, D)
        k = self.k_net(x).reshape(B, T, N, D).transpose(0, 2, 1, 3)
        v = self.v_net(x).reshape(B, T, N, D).transpose(0, 2, 1, 3)

        # scaled dot-product weights (B, N, T, T)
        weights = jnp.matmul(q, jnp.swapaxes(k, -2, -1)) / math.sqrt(D)
        normalized_weights = nn.softmax(weights, axis=-1)

        # attention (B, N, T, D)
        attention = jnp.matmul(normalized_weights, v)
        attention = self.att_drop(attention, deterministic=not train)

        # gather heads back to (B, T, N*D)
        attention = attention.transpose(0, 2, 1, 3).reshape(B, T, N * D)

        # output projection
        out = self.proj_drop(self.proj_net(attention), deterministic=not train)

        return out
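As a quick check of those shapes (a sketch with hidden_dim=192 and n_heads=3, so each head has D=64):

mha = MultiHeadSelfAttention(hidden_dim=192, n_heads=3, drop_p=0.1)
x = jnp.ones((5, 5, 192))  # (B, T, C): the [CLS] + 4 patch tokens from the encoder example
variables = mha.init({'params': jax.random.PRNGKey(2), 'dropout': jax.random.PRNGKey(3)}, x, train=True)
out = mha.apply(variables, x, train=True, rngs={'dropout': jax.random.PRNGKey(4)})
print(out.shape)  # (5, 5, 192); internally q, k, v each have shape (5, 3, 5, 64)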

5. Classification with the CLS embedding

Finally, the full model, with an MLP head (classification head) applied to the [CLS] token.

class ViT(nn.Module):
    patch_size: int
    embed_dim: int
    hidden_dim: int
    n_heads: int
    drop_p: float
    num_layers: int
    mlp_dim: int
    num_classes: int

    def setup(self):
        self.patch_extracter = Patches(self.patch_size, self.embed_dim)
        self.patch_encoder = PatchEncoder(self.hidden_dim)
        self.dropout = nn.Dropout(self.drop_p)
        self.transformer_encoder = TransformerEncoder(self.embed_dim, self.hidden_dim, self.n_heads, self.drop_p, self.mlp_dim)
        self.cls_head = nn.Dense(features=self.num_classes)

    def __call__(self, x, train=True):
        x = self.patch_extracter(x)
        x = self.patch_encoder(x)
        x = self.dropout(x, deterministic=not train)
        # Note: the single encoder instance is applied num_layers times,
        # so all layers share parameters (kept as in the original post)
        for i in range(self.num_layers):
            x = self.transformer_encoder(x, train)
        # MLP head on the [CLS] token
        x = x[:, 0]
        x = self.cls_head(x)
        return x

Training with JAX/Flax

Now that the model is defined, let's train it with JAX/Flax.

Dataset

We use CIFAR10 from torchvision directly.

First, some utility functions:

import numpy as np

def image_to_numpy(img):
    # Normalize with the dataset statistics (DATA_MEANS / DATA_STD, defined below)
    img = np.array(img, dtype=np.float32)
    img = (img / 255. - DATA_MEANS) / DATA_STD
    return img

def numpy_collate(batch):
    # Collate PyTorch batches into numpy arrays for JAX
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple, list)):
        transposed = zip(*batch)
        return [numpy_collate(samples) for samples in transposed]
    else:
        return np.array(batch)
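image_to_numpy relies on DATA_MEANS and DATA_STD, which the post never defines; a minimal sketch, assuming they are the channel-wise statistics of the raw CIFAR-10 training images:

import numpy as np
from torchvision.datasets import CIFAR10

# Assumed definition: per-channel mean/std of the CIFAR-10 training set
# (roughly [0.49, 0.48, 0.45] and [0.25, 0.24, 0.26])
raw_train = CIFAR10('data', train=True, download=True)
raw_images = np.stack([np.array(img) for img, _ in raw_train]) / 255.
DATA_MEANS = raw_images.mean(axis=(0, 1, 2))
DATA_STD = raw_images.std(axis=(0, 1, 2))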

Then the dataloaders for training, validation and testing:

import torch
from torchvision import transforms

test_transform = image_to_numpy
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomResizedCrop((IMAGE_SIZE, IMAGE_SIZE), scale=CROP_SCALES, ratio=CROP_RATIO),
    image_to_numpy
])

# The validation set should not use augmentation.
train_dataset = CIFAR10('data', train=True, transform=train_transform, download=True)
val_dataset = CIFAR10('data', train=True, transform=test_transform, download=True)
train_set, _ = torch.utils.data.random_split(train_dataset, [45000, 5000], generator=torch.Generator().manual_seed(SEED))
_, val_set = torch.utils.data.random_split(val_dataset, [45000, 5000], generator=torch.Generator().manual_seed(SEED))
test_set = CIFAR10('data', train=False, transform=test_transform, download=True)

train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, num_workers=2, persistent_workers=True, collate_fn=numpy_collate,
)
val_loader = torch.utils.data.DataLoader(
    val_set, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, num_workers=2, persistent_workers=True, collate_fn=numpy_collate,
)
test_loader = torch.utils.data.DataLoader(
    test_set, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, num_workers=2, persistent_workers=True, collate_fn=numpy_collate,
)

Model initialization

Initialize the ViT model:

import jax
from jax import random

def initialize_model(
    seed=42,
    patch_size=16, embed_dim=192, hidden_dim=192,
    n_heads=3, drop_p=0.1, num_layers=12, mlp_dim=768, num_classes=10
):
    main_rng = jax.random.PRNGKey(seed)
    x = jnp.ones(shape=(5, 32, 32, 3))
    # ViT
    model = ViT(
        patch_size=patch_size,
        embed_dim=embed_dim,
        hidden_dim=hidden_dim,
        n_heads=n_heads,
        drop_p=drop_p,
        num_layers=num_layers,
        mlp_dim=mlp_dim,
        num_classes=num_classes
    )
    main_rng, init_rng, drop_rng = random.split(main_rng, 3)
    params = model.init({'params': init_rng, 'dropout': drop_rng}, x, train=True)['params']
    return model, params, main_rng

vit_model, vit_params, vit_rng = initialize_model()

Creating the TrainState

A common pattern in Flax is to create a class that manages the training state, including the step count, optimizer state, and model parameters. Specifying apply_fn, which corresponds to the model's forward pass, also keeps the argument lists in the training loop short.

import optax
from flax.training import train_state

def create_train_state(
    model, params, learning_rate
):
    optimizer = optax.adam(learning_rate)
    return train_state.TrainState.create(
        apply_fn=model.apply,
        tx=optimizer,
        params=params
    )

state = create_train_state(vit_model, vit_params, 3e-4)

Training loop

from collections import defaultdict
from tqdm import tqdm

def train_model(train_loader, val_loader, state, rng, num_epochs=100):
    best_eval = 0.0
    for epoch_idx in tqdm(range(1, num_epochs + 1)):
        state, rng = train_epoch(train_loader, epoch_idx, state, rng)
        if epoch_idx % 1 == 0:  # evaluate every epoch
            eval_acc = eval_model(val_loader, state, rng)
            logger.add_scalar('val/acc', eval_acc, global_step=epoch_idx)
            if eval_acc >= best_eval:
                best_eval = eval_acc
                save_model(state, step=epoch_idx)
            logger.flush()
    # Evaluate after training
    test_acc = eval_model(test_loader, state, rng)
    print(f'test_acc: {test_acc}')

def train_epoch(train_loader, epoch_idx, state, rng):
    metrics = defaultdict(list)
    for batch in tqdm(train_loader, desc='Training', leave=False):
        state, rng, loss, acc = train_step(state, rng, batch)
        metrics['loss'].append(loss)
        metrics['acc'].append(acc)
    for key in metrics.keys():
        avg_val = np.stack(jax.device_get(metrics[key])).mean()
        logger.add_scalar('train/' + key, avg_val, global_step=epoch_idx)
        print(f'[epoch {epoch_idx}] {key}: {avg_val}')
    return state, rng
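train_model and train_epoch reference a logger and a save_model helper that are not shown in the post. Given the add_scalar calls, the logger is presumably a TensorBoard SummaryWriter; a sketch under that assumption (the log and checkpoint paths are made up here):

from torch.utils.tensorboard import SummaryWriter
from flax.training import checkpoints

logger = SummaryWriter(log_dir='logs/vit')  # assumed log directory

def save_model(state, step):
    # Persist the TrainState (parameters + optimizer state) as a Flax checkpoint
    checkpoints.save_checkpoint(ckpt_dir='checkpoints/vit', target=state, step=step, overwrite=True)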

Evaluation

def eval_model(data_loader, state, rng):
    # Test the model on all images of a data loader and return the average accuracy
    correct_class, count = 0, 0
    for batch in data_loader:
        rng, acc = eval_step(state, rng, batch)
        correct_class += acc * batch[0].shape[0]
        count += batch[0].shape[0]
    eval_acc = (correct_class / count).item()
    return eval_acc
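eval_model depends on an eval_step that the post does not define; a minimal sketch consistent with calculate_loss (shown below):

@jax.jit
def eval_step(state, rng, batch):
    # Forward pass with dropout disabled; only the accuracy and rng are needed
    _, (acc, rng) = calculate_loss(state.params, state, rng, batch, train=False)
    return rng, acc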

Training step

train_step defines the loss function, computes the gradients of the model parameters, and updates the parameters from those gradients. jax.value_and_grad evaluates the loss and its gradients at state.params, and apply_gradients returns the updated TrainState. The cross-entropy loss is computed from the logits produced by apply_fn (the same as model.apply), which was specified when creating the TrainState.

@jax.jit
def train_step(state, rng, batch):
    loss_fn = lambda params: calculate_loss(params, state, rng, batch, train=True)
    # Get the loss, its gradients, and the auxiliary outputs of the loss function
    (loss, (acc, rng)), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    # Update parameters
    state = state.apply_gradients(grads=grads)
    return state, rng, loss, acc

Computing the loss

def calculate_loss(params, state, rng, batch, train):
    imgs, labels = batch
    rng, drop_rng = random.split(rng)
    logits = state.apply_fn({'params': params}, imgs, train=train, rngs={'dropout': drop_rng})
    loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=labels).mean()
    acc = (logits.argmax(axis=-1) == labels).mean()
    return loss, (acc, rng)
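With the model, state, and loaders in place, training can be launched as follows (a sketch; the original post does not show this call explicitly):

train_model(train_loader, val_loader, state, vit_rng, num_epochs=100)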

Results

The training results are shown below. Training took about 1.5 hours on a standard GPU in Colab Pro.

test_acc: 0.7704000473022461

If you are interested in JAX, the full code for this article is available here:

https://avoid.overfit.cn/post/926b7965ba56464ba151cbbfb6a98a93

Author: satojkovic
