diff --git a/ACL_PyTorch/contrib/cv/classfication/GMA/GMA.patch b/ACL_PyTorch/contrib/cv/classfication/GMA/GMA.patch new file mode 100644 index 0000000000000000000000000000000000000000..968f0492e700c49eb9a67380b89db20908cf1ce0 --- /dev/null +++ b/ACL_PyTorch/contrib/cv/classfication/GMA/GMA.patch @@ -0,0 +1,541 @@ +diff --git a/core/corr.py b/core/corr.py +index d9744d4..d9dd15d 100644 +--- a/core/corr.py ++++ b/core/corr.py +@@ -26,7 +26,7 @@ class CorrBlock: + + self.corr_pyramid.append(corr) + for i in range(self.num_levels - 1): +- corr = F.avg_pool2d(corr, 2, stride=2) ++ corr = F.avg_pool2d(corr.float(), 2, stride=2) + self.corr_pyramid.append(corr) + + def __call__(self, coords): +@@ -52,9 +52,11 @@ class CorrBlock: + out = torch.cat(out_pyramid, dim=-1) + return out.permute(0, 3, 1, 2).contiguous().float() + ++ + @staticmethod + def corr(fmap1, fmap2): +- batch, dim, ht, wd = fmap1.shape ++ ++ batch, dim, ht, wd = fmap1.shape # fmap1.shape = torch.Size([1, 256, 55, 128]) + fmap1 = fmap1.view(batch, dim, ht * wd) + fmap2 = fmap2.view(batch, dim, ht * wd) + +@@ -63,6 +65,7 @@ class CorrBlock: + return corr / torch.sqrt(torch.tensor(dim).float()) + + ++ + class CorrBlockSingleScale(nn.Module): + def __init__(self, fmap1, fmap2, num_levels=4, radius=4): + super().__init__() +@@ -89,7 +92,8 @@ class CorrBlockSingleScale(nn.Module): + + corr = bilinear_sampler(corr, coords_lvl) + out = corr.view(batch, h1, w1, -1) +- out = out.permute(0, 3, 1, 2).contiguous().float() ++ ++ out = out.permute(0, 3, 1, 2).contiguous().half() + return out + + @staticmethod +@@ -100,4 +104,5 @@ class CorrBlockSingleScale(nn.Module): + + corr = torch.matmul(fmap1.transpose(1, 2), fmap2) + corr = corr.view(batch, ht, wd, 1, ht, wd) +- return corr / torch.sqrt(torch.tensor(dim).float()) ++ return corr / torch.sqrt(torch.tensor(dim).half()) ++ +diff --git a/core/datasets.py b/core/datasets.py +index e7f1528..1f01147 100644 +--- a/core/datasets.py ++++ b/core/datasets.py +@@ -36,6 +36,9 @@ class FlowDataset(data.Dataset): + + def __getitem__(self, index): + ++ src_image1_name = self.image_list[index][0] ++ src_image2_name = self.image_list[index][1] ++ + if self.is_test: + img1 = frame_utils.read_gen(self.image_list[index][0]) + img2 = frame_utils.read_gen(self.image_list[index][1]) +@@ -43,7 +46,7 @@ class FlowDataset(data.Dataset): + img2 = np.array(img2).astype(np.uint8)[..., :3] + img1 = torch.from_numpy(img1).permute(2, 0, 1).float() + img2 = torch.from_numpy(img2).permute(2, 0, 1).float() +- return img1, img2, self.extra_info[index] ++ return src_image1_name, src_image2_name, img1, img2, self.extra_info[index] + + if not self.init_seed: + worker_info = torch.utils.data.get_worker_info() +@@ -109,11 +112,11 @@ class FlowDataset(data.Dataset): + valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000) + + if self.occ_list is not None: +- return img1, img2, flow, valid.float(), occ, self.occ_list[index] ++ return src_image1_name, src_image2_name, img1, img2, flow, valid.float(), occ, self.occ_list[index] + elif self.seg_list is not None and self.seg_inv_list is not None: +- return img1, img2, flow, valid.float(), seg_map, seg_inv ++ return src_image1_name, src_image2_name, img1, img2, flow, valid.float(), seg_map, seg_inv + else: +- return img1, img2, flow, valid.float()#, self.extra_info[index] ++ return src_image1_name, src_image2_name, img1, img2, flow, valid.float()#, self.extra_info[index] + + def __rmul__(self, v): + self.flow_list = v * self.flow_list +@@ -125,18 +128,18 @@ class 
FlowDataset(data.Dataset): + + + class MpiSintel(FlowDataset): +- def __init__(self, aug_params=None, split='training', root='/home/zac/data/Sintel', dstype='clean', ++ def __init__(self, aug_params=None, split='training', root='/home/lq/Sintel/', dstype='clean', + occlusion=False, segmentation=False): + super(MpiSintel, self).__init__(aug_params) +- flow_root = osp.join(root, split, 'flow') +- image_root = osp.join(root, split, dstype) ++ flow_root = osp.join(root, split, 'flow') # true ++ image_root = osp.join(root, split, dstype) # true + # occ_root = osp.join(root, split, 'occlusions') + # occ_root = osp.join(root, split, 'occ_plus_out') + # occ_root = osp.join(root, split, 'in_frame_occ') +- occ_root = osp.join(root, split, 'out_of_frame') ++ occ_root = osp.join(root, split, 'out_of_frame') # false + +- seg_root = osp.join(root, split, 'segmentation') +- seg_inv_root = osp.join(root, split, 'segmentation_invalid') ++ seg_root = osp.join(root, split, 'segmentation') # false ++ seg_inv_root = osp.join(root, split, 'segmentation_invalid') # false + self.segmentation = segmentation + self.occlusion = occlusion + if self.occlusion: +@@ -145,7 +148,8 @@ class MpiSintel(FlowDataset): + self.seg_list = [] + self.seg_inv_list = [] + +- if split == 'test': ++ # split == 'training' ++ if split == 'test': + self.is_test = True + + for scene in os.listdir(image_root): +@@ -266,7 +270,7 @@ class HD1K(FlowDataset): + seq_ix += 1 + + +-def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'): ++def fetch_dataloader(args, rank, TRAIN_DS='C+T+K+S+H'): + """ Create the data loader for the corresponding training set """ + + if args.stage == 'chairs': +@@ -297,8 +301,25 @@ def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'): + aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False} + train_dataset = KITTI(aug_params, split='training') + +- train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, +- pin_memory=True, shuffle=True, num_workers=8, drop_last=True) ++ # train_sampler = torch.utils.data.distributed.DistributedSampler( ++ # train_dataset, ++ # num_replicas=args.world_size, ++ # rank=rank ++ # ) ++ ++ train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) ++ ++ ++ train_loader = data.DataLoader( ++ dataset=train_dataset, ++ batch_size=args.batch_size, ++ pin_memory=False, ++ shuffle=False, # ++ num_workers=args.workers, ++ drop_last=True, ++ sampler=train_sampler # ++ ) ++ + + print('Training with %d image pairs' % len(train_dataset)) + return train_loader +diff --git a/core/gma.py b/core/gma.py +index c1c8449..ae8c40f 100644 +--- a/core/gma.py ++++ b/core/gma.py +@@ -1,6 +1,131 @@ + import torch + from torch import nn, einsum + from einops import rearrange ++import pdb ++import traceback ++pdb.set_trace = lambda: 1 ++ ++ ++def einsum0(eq, operands): ++ # 'b h x y d, x u v d -> b h x y u v' ++ ++ tmp = torch.einsum(eq, operands) ++ # if not torch.onnx.is_in_onnx_export(): ++ # return tmp ++ ++ try: ++ x0, y0 = operands ++ assert (len(x0.shape) == 5 and len(y0.shape) == 4) ++ b, h, x, y, d = x0.shape ++ x, u, v, d = y0.shape ++ ++ x_tensor = x0.clone() ++ y_tensor = y0.clone() ++ ++ x_tensor = x_tensor.reshape(b,h,x,y,1,1,d) ++ y_tensor = y_tensor.reshape(1,1,x,1,u,v,d) ++ ++ z_tensor = x_tensor * y_tensor ++ z_tensor = z_tensor.sum(dim=6) ++ ++ assert tmp.equal(z_tensor) ++ return z_tensor ++ except Exception as e: ++ print('str(e):\t\t', str(e)) ++ print('repr(e):\t', repr(e)) ++ print('traceback.print_exc():', 
traceback.print_exc()) ++ print('traceback.format_exc():\n%s' % traceback.format_exc()) ++ pdb.set_trace() ++ ++def einsum1(eq, operands): ++ # 'b h x y d, y u v d -> b h x y u v' ++ ++ tmp = torch.einsum(eq, operands) ++ # if not torch.onnx.is_in_onnx_export(): ++ # return tmp ++ try: ++ x0, y0 = operands ++ assert (len(x0.shape) == 5 and len(y0.shape) == 4) ++ b, h, x, y, d = x0.shape ++ y, u, v, d = y0.shape ++ ++ x_tensor = x0.clone() ++ y_tensor = y0.clone() ++ ++ x_tensor = x_tensor.reshape(b,h,x,y,1,1,d) ++ y_tensor = y_tensor.reshape(1,1,1,y,u,v,d) ++ ++ z_tensor = x_tensor * y_tensor ++ z_tensor = z_tensor.sum(dim=6) ++ ++ assert tmp.equal(z_tensor) ++ return z_tensor ++ except Exception as e: ++ print('str(e):\t\t', str(e)) ++ print('repr(e):\t', repr(e)) ++ print('traceback.print_exc():', traceback.print_exc()) ++ print('traceback.format_exc():\n%s' % traceback.format_exc()) ++ pdb.set_trace() ++ ++def einsum2(eq, operands): ++ # 'b h x y d, b h u v d -> b h x y u v' ++ tmp = torch.einsum(eq, operands) ++ # if not torch.onnx.is_in_onnx_export(): ++ # return tmp ++ try: ++ x0, y0 = operands ++ assert (len(x0.shape) == 5 and len(y0.shape) == 5) ++ assert (x0.shape[:2] == y0.shape[:2]) ++ b, h, x, y, d = x0.shape ++ b, h, u, v, d = y0.shape ++ ++ x_tensor = x0.clone() ++ y_tensor = y0.clone() ++ ++ x_tensor = x_tensor.reshape(b*h,x*y,d) # [b*h,x*y,d] ++ y_tensor = y_tensor.reshape(b*h,u*v,d).permute(0,2,1) # [b*h,d,u*v] ++ ++ z_tensor = torch.bmm(x_tensor, y_tensor) ++ z_tensor = z_tensor.reshape(b,h,x,y,u,v) ++ ++ ++ assert tmp.equal(z_tensor) ++ return z_tensor ++ except Exception as e: ++ print('str(e):\t\t', str(e)) ++ print('repr(e):\t', repr(e)) ++ print('traceback.print_exc():', traceback.print_exc()) ++ print('traceback.format_exc():\n%s' % traceback.format_exc()) ++ pdb.set_trace() ++ ++def einsum3(eq, operands): ++ # 'b h i j, b h j d -> b h i d' ++ tmp = torch.einsum(eq, operands) ++ # if not torch.onnx.is_in_onnx_export(): ++ # return tmp ++ try: ++ x0, y0 = operands ++ assert (len(x0.shape) == 4 and len(y0.shape) == 4) ++ assert (x0.shape[:2] == y0.shape[:2]) ++ b, h, i, j = x0.shape ++ b, h, j, d = y0.shape ++ ++ x_tensor = x0.clone() ++ y_tensor = y0.clone() ++ ++ x_tensor = x_tensor.reshape(b*h,i,j) ++ y_tensor = y_tensor.reshape(b*h,j,d) ++ z_tensor = torch.bmm(x_tensor,y_tensor) ++ z_tensor = z_tensor.reshape(b,h,i,d) ++ assert tmp.equal(z_tensor) ++ return z_tensor ++ except Exception as e: ++ print('str(e):\t\t', str(e)) ++ print('repr(e):\t', repr(e)) ++ print('traceback.print_exc():', traceback.print_exc()) ++ print('traceback.format_exc():\n%s' % traceback.format_exc()) ++ pdb.set_trace() ++ + + + class RelPosEmb(nn.Module): +@@ -25,8 +150,10 @@ class RelPosEmb(nn.Module): + height_emb = rearrange(height_emb, '(x u) d -> x u () d', x=h) + width_emb = rearrange(width_emb, '(y v) d -> y () v d', y=w) + +- height_score = einsum('b h x y d, x u v d -> b h x y u v', q, height_emb) +- width_score = einsum('b h x y d, y u v d -> b h x y u v', q, width_emb) ++ # height_score = einsum('b h x y d, x u v d -> b h x y u v', q, height_emb) ++ height_score = einsum0('b h x y d, x u v d -> b h x y u v', (q, height_emb)) ++ # width_score = einsum('b h x y d, y u v d -> b h x y u v', q, width_emb) ++ width_score = einsum1('b h x y d, y u v d -> b h x y u v', (q, width_emb)) + + return height_score + width_score + +@@ -49,7 +176,8 @@ class Attention(nn.Module): + + self.to_qk = nn.Conv2d(dim, inner_dim * 2, 1, bias=False) + +- self.pos_emb = RelPosEmb(max_pos_size, dim_head) 
++ if self.args.position_only or self.args.position_and_content: ++ self.pos_emb = RelPosEmb(max_pos_size, dim_head) + + def forward(self, fmap): + heads, b, c, h, w = self.heads, *fmap.shape +@@ -63,12 +191,14 @@ class Attention(nn.Module): + sim = self.pos_emb(q) + + elif self.args.position_and_content: +- sim_content = einsum('b h x y d, b h u v d -> b h x y u v', q, k) ++ # sim_content = einsum('b h x y d, b h u v d -> b h x y u v', q, k) ++ sim_content = einsum2('b h x y d, b h u v d -> b h x y u v', (q, k)) + sim_pos = self.pos_emb(q) + sim = sim_content + sim_pos + + else: +- sim = einsum('b h x y d, b h u v d -> b h x y u v', q, k) ++ # sim = einsum('b h x y d, b h u v d -> b h x y u v', q, k) ++ sim = einsum2('b h x y d, b h u v d -> b h x y u v', (q, k)) + + sim = rearrange(sim, 'b h x y u v -> b h (x y) (u v)') + attn = sim.softmax(dim=-1) +@@ -104,7 +234,8 @@ class Aggregate(nn.Module): + + v = self.to_v(fmap) + v = rearrange(v, 'b (h d) x y -> b h (x y) d', h=heads) +- out = einsum('b h i j, b h j d -> b h i d', attn, v) ++ # out = einsum('b h i j, b h j d -> b h i d', attn, v) ++ out = einsum3('b h i j, b h j d -> b h i d', (attn, v)) + out = rearrange(out, 'b h (x y) d -> b (h d) x y', x=h, y=w) + + if self.project is not None: +diff --git a/core/network.py b/core/network.py +index 73a89ac..08d30d3 100644 +--- a/core/network.py ++++ b/core/network.py +@@ -87,6 +87,7 @@ class RAFTGMA(nn.Module): + + fmap1 = fmap1.float() + fmap2 = fmap2.float() ++ + corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) + + # run the context network +@@ -109,8 +110,12 @@ class RAFTGMA(nn.Module): + corr = corr_fn(coords1) # index correlation volume + + flow = coords1 - coords0 +- with autocast(enabled=self.args.mixed_precision): +- net, up_mask, delta_flow = self.update_block(net, inp, corr, flow, attention) ++ ++ flow = flow.float() ++ ++ ++ # with autocast(enabled=self.args.mixed_precision): ++ net, up_mask, delta_flow = self.update_block(net, inp, corr, flow, attention) + + # F(t+1) = F(t) + \Delta(t) + coords1 = coords1 + delta_flow +diff --git a/core/update.py b/core/update.py +index 91a38b4..ece75d4 100644 +--- a/core/update.py ++++ b/core/update.py +@@ -74,12 +74,17 @@ class BasicMotionEncoder(nn.Module): + self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1) + + def forward(self, flow, corr): ++ + cor = F.relu(self.convc1(corr)) + cor = F.relu(self.convc2(cor)) + flo = F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + +- cor_flo = torch.cat([cor, flo], dim=1) ++ cor = cor.half() ++ flo = flo.half() ++ cor_flo = torch.cat([cor, flo], dim=1) # 出现精度问题 ++ ++ cor_flo = cor_flo.float() + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + +@@ -98,6 +103,7 @@ class BasicUpdateBlock(nn.Module): + nn.Conv2d(256, 64*9, 1, padding=0)) + + def forward(self, net, inp, corr, flow, upsample=True): ++ + motion_features = self.encoder(flow, corr) + inp = torch.cat([inp, motion_features], dim=1) + +@@ -125,6 +131,8 @@ class GMAUpdateBlock(nn.Module): + self.aggregator = Aggregate(args=self.args, dim=128, dim_head=128, heads=self.args.num_heads) + + def forward(self, net, inp, corr, flow, attention): ++ ++ + motion_features = self.encoder(flow, corr) + motion_features_global = self.aggregator(attention, motion_features) + inp_cat = torch.cat([inp, motion_features, motion_features_global], dim=1) +diff --git a/core/utils/utils.py b/core/utils/utils.py +index f64841b..ae0bce6 100644 +--- a/core/utils/utils.py ++++ b/core/utils/utils.py +@@ -3,7 +3,8 @@ import 
torch.nn.functional as F + import numpy as np + from scipy import interpolate + # from torch_scatter import scatter_softmax, scatter_add +- ++# import mmcv ++# import mmcv.ops + + class InputPadder: + """ Pads images such that dimensions are divisible by 8 """ +@@ -53,7 +54,87 @@ def forward_interpolate(flow): + (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) + + flow = np.stack([flow_x, flow_y], axis=0) +- return torch.from_numpy(flow).float() ++ return torch.from_numpy(flow).half() ++ ++ ++ ++ ++ ++ ++def bilinear_grid_sample(im, grid, align_corners=False): ++ """Given an input and a flow-field grid, computes the output using input ++ values and pixel locations from grid. Supported only bilinear interpolation ++ method to sample the input pixels. ++ ++ Args: ++ im (torch.Tensor): Input feature map, shape (N, C, H, W) ++ grid (torch.Tensor): Point coordinates, shape (N, Hg, Wg, 2) ++ align_corners {bool}: If set to True, the extrema (-1 and 1) are ++ considered as referring to the center points of the input’s ++ corner pixels. If set to False, they are instead considered as ++ referring to the corner points of the input’s corner pixels, ++ making the sampling more resolution agnostic. ++ Returns: ++ torch.Tensor: A tensor with sampled points, shape (N, C, Hg, Wg) ++ """ ++ n, c, h, w = im.shape ++ gn, gh, gw, _ = grid.shape ++ assert n == gn ++ ++ x = grid[:, :, :, 0] ++ y = grid[:, :, :, 1] ++ ++ if align_corners: ++ x = ((x + 1) / 2) * (w - 1) ++ y = ((y + 1) / 2) * (h - 1) ++ else: ++ x = ((x + 1) * w - 1) / 2 ++ y = ((y + 1) * h - 1) / 2 ++ ++ x = x.view(n, -1) ++ y = y.view(n, -1) ++ ++ x0 = torch.floor(x).int() ++ y0 = torch.floor(y).int() ++ x1 = x0 + 1 ++ y1 = y0 + 1 ++ wa = ((x1 - x) * (y1 - y)).unsqueeze(1) ++ wb = ((x1 - x) * (y - y0)).unsqueeze(1) ++ wc = ((x - x0) * (y1 - y)).unsqueeze(1) ++ wd = ((x - x0) * (y - y0)).unsqueeze(1) ++ # Apply default for grid_sample function zero padding ++ im_padded = F.pad(im, pad=[1, 1, 1, 1], mode='constant', value=0) ++ padded_h = h + 2 ++ padded_w = w + 2 ++ # save points positions after padding ++ x0, x1, y0, y1 = x0 + 1, x1 + 1, y0 + 1, y1 + 1 ++ # Clip coordinates to padded image size ++ x0 = torch.where(x0 < 0, torch.tensor(0,dtype=torch.int32), x0) ++ x0 = torch.where(x0 > padded_w - 1, torch.tensor(padded_w - 1,dtype=torch.int32), x0) ++ x1 = torch.where(x1 < 0, torch.tensor(0,dtype=torch.int32), x1) ++ x1 = torch.where(x1 > padded_w - 1, torch.tensor(padded_w - 1,dtype=torch.int32), x1) ++ y0 = torch.where(y0 < 0, torch.tensor(0,dtype=torch.int32), y0) ++ y0 = torch.where(y0 > padded_h - 1, torch.tensor(padded_h - 1,dtype=torch.int32), y0) ++ y1 = torch.where(y1 < 0, torch.tensor(0,dtype=torch.int32), y1) ++ y1 = torch.where(y1 > padded_h - 1, torch.tensor(padded_h - 1,dtype=torch.int32), y1) ++ ++ im_padded = im_padded.view(n, c, -1) ++ ++ ++ x0_y0 = (x0 + y0 * padded_w).unsqueeze(1).type(torch.int32).expand(-1, c, -1) ++ x0_y1 = (x0 + y1 * padded_w).unsqueeze(1).type(torch.int32).expand(-1, c, -1) ++ x1_y0 = (x1 + y0 * padded_w).unsqueeze(1).type(torch.int32).expand(-1, c, -1) ++ x1_y1 = (x1 + y1 * padded_w).unsqueeze(1).type(torch.int32).expand(-1, c, -1) ++ ++ ++ Ia = torch.gather(im_padded, 2, x0_y0.type(torch.int64)) ++ Ib = torch.gather(im_padded, 2, x0_y1.type(torch.int64)) ++ Ic = torch.gather(im_padded, 2, x1_y0.type(torch.int64)) ++ Id = torch.gather(im_padded, 2, x1_y1.type(torch.int64)) ++ ++ ++ ++ return (Ia * wa + Ib * wb + Ic * wc + Id * wd).reshape(n, c, gh, gw) + + + def bilinear_sampler(img, coords, 
mode='bilinear', mask=False): +@@ -64,11 +145,12 @@ def bilinear_sampler(img, coords, mode='bilinear', mask=False): + ygrid = 2*ygrid/(H-1) - 1 + + grid = torch.cat([xgrid, ygrid], dim=-1) +- img = F.grid_sample(img, grid, align_corners=True) +- ++ # img = F.grid_sample(img, grid, align_corners=True) ++ img = bilinear_grid_sample(img, grid, align_corners=True) + if mask: + mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) +- return img, mask.float() ++ return img, mask.half() ++ + + return img + +@@ -76,6 +158,7 @@ def bilinear_sampler(img, coords, mode='bilinear', mask=False): + def coords_grid(batch, ht, wd): + coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) + coords = torch.stack(coords[::-1], dim=0).float() ++ + return coords[None].expand(batch, -1, -1, -1) + + diff --git a/ACL_PyTorch/contrib/cv/classfication/GMA/GMA_postprocess.py b/ACL_PyTorch/contrib/cv/classfication/GMA/GMA_postprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..e887a9635ccd86cec371a2014db861e3a1ab1363 --- /dev/null +++ b/ACL_PyTorch/contrib/cv/classfication/GMA/GMA_postprocess.py @@ -0,0 +1,160 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import argparse +import os +import os.path as osp +import numpy as np +import torch.nn.functional as F +import time +import onnxruntime as rt + + + +def perf_g(onnx_path, onnx_input_size, loop=20): + + time_cost_list = [] + + for idx in range(loop): + print(f"loop : {idx}") + + tic = time.time() # s + sess = rt.InferenceSession(onnx_path, providers=['CUDAExecutionProvider'], provider_options=[ {'device_id': 0}]) + input_name = [] + for n in sess.get_inputs(): + input_name.append(n.name) + + output_name = [] + for n in sess.get_outputs(): + output_name.append(n.name) + + image_1 = np.random.randn(1,3,onnx_input_size[0], onnx_input_size[1]).astype(np.float32) + image_2 = np.random.randn(1,3,onnx_input_size[0], onnx_input_size[1]).astype(np.float32) + + input_data = [image_1,image_2] + + flow_pr = sess.run(None, {input_name[i]: input_data[i] for i in range(len(input_name))})[-1] + + toc = time.time() # s + time_cost_list.append(toc-tic) + + second_per_sample = sum(time_cost_list) / len(time_cost_list) + fps = 1 / second_per_sample + print(f"t4 bs1 fps:{fps}") + + + +class InputPadder: + """ Pads images such that dimensions are divisible by 8 """ + def __init__(self, dims, mode='sintel'): + self.ht, self.wd = dims[-2:] + pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 + pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 + if mode == 'sintel': + self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] + else: + self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] + + def pad(self, *inputs): + return [F.pad(x, self._pad, mode='replicate') for x in inputs] + + def unpad(self,x): + ht, wd = x.shape[-2:] + c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] + return x[..., c[0]:c[1], c[2]:c[3]] + +def 
load_statistical_predict_result(predict_result_dir): + + ret = None + for root, dirs, files in os.walk(predict_result_dir): + + if files != []: + path = osp.join(root, files[0]) + ret = np.fromfile(path, dtype=np.float32) + # print(path) + + return ret + + +def postprocess(om_result_path, bin_src_path, info_name): + image_shape_before_pad = (436,1024) + image_shape_after_pad = (440,1024) + + epe_list = [] + with open(info_name, 'r') as f: + while True: + item = f.readline().strip('\n').split(' ') + if (item == ['']): + break + inference_output_path_dir = osp.join(om_result_path,item[3]) + flow_gt_path = item[2] + + padder = InputPadder(image_shape_before_pad) + + try: + flow_pr = load_statistical_predict_result(inference_output_path_dir).reshape(1,2,image_shape_after_pad[0],image_shape_after_pad[1]) + flow_gt = np.fromfile(flow_gt_path,dtype=np.float32).reshape(2,image_shape_before_pad[0], image_shape_before_pad[1]) + except Exception as e: + print(inference_output_path_dir) + raise Exception(str(e)) + + + + flow = padder.unpad(flow_pr[0]) + + flow = torch.tensor(flow) + flow_gt = torch.tensor(flow_gt) + + # 精度统计 + epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt() + epe_list.append(epe.view(-1).numpy()) + + # 精度统计 + epe_all = np.concatenate(epe_list) + epe = np.mean(epe_all) + px1 = np.mean(epe_all<1) + px3 = np.mean(epe_all<3) + px5 = np.mean(epe_all<5) + + # 输出 + dstype = 'clean' + print("Validation (%s) EPE: %f, 1px: %f, 3px: %f, 5px: %f" % (dstype, epe, px1, px3, px5)) + + + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='GMA post process script') + parser.add_argument('--res', default='', type=str, metavar='PATH', + help='om result path') + parser.add_argument('--bin_src', default='', type=str, metavar='PATH', + help='bin src path') + + parser.add_argument('--info_name', default='', type=str, metavar='PATH', + help='info file path') + + parser.add_argument('--test_perf', default=False, action='store_true') + parser.add_argument('--onnx_path', default='', type=str) + parser.add_argument('--onnx_input_size', type=int, nargs='+', default=[440, 1024]) + + + args = parser.parse_args() + + if not args.test_perf: + postprocess(om_result_path=args.res, bin_src_path=args.bin_src, info_name=args.info_name) + else: + perf_g(args.onnx_path, args.onnx_input_size) \ No newline at end of file diff --git a/ACL_PyTorch/contrib/cv/classfication/GMA/GMA_preprocess.py b/ACL_PyTorch/contrib/cv/classfication/GMA/GMA_preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..c893a2eb39caf235b3abc1fdc989af0c6f99c7bf --- /dev/null +++ b/ACL_PyTorch/contrib/cv/classfication/GMA/GMA_preprocess.py @@ -0,0 +1,766 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse + +import torch +import torch.utils.data as data +from torchvision.transforms import ColorJitter +import torch.nn.functional as F +from PIL import Image + +import cv2 + +import numpy as np +import os +import os.path as osp +import random +from glob import glob +import re + +TAG_CHAR = np.array([202021.25], np.float32) + +def readFlow(fn): + """ Read .flo file in Middlebury format""" + # Code adapted from: + # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy + + # WARNING: this will work on little-endian architectures (eg Intel x86) only! + # print 'fn = %s'%(fn) + with open(fn, 'rb') as f: + magic = np.fromfile(f, np.float32, count=1) + if 202021.25 != magic: + print('Magic number incorrect. Invalid .flo file') + return None + else: + w = np.fromfile(f, np.int32, count=1) + h = np.fromfile(f, np.int32, count=1) + # print 'Reading %d x %d flo file\n' % (w, h) + data = np.fromfile(f, np.float32, count=2*int(w)*int(h)) + # Reshape data into 3D array (columns, rows, bands) + # The reshape here is for visualization, the original code is (w,h,2) + return np.resize(data, (int(h), int(w), 2)) + +def readPFM(file): + file = open(file, 'rb') + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header == b'PF': + color = True + elif header == b'Pf': + color = False + else: + raise Exception('Not a PFM file.') + + dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) + if dim_match: + width, height = map(int, dim_match.groups()) + else: + raise Exception('Malformed PFM header.') + + scale = float(file.readline().rstrip()) + if scale < 0: # little-endian + endian = '<' + scale = -scale + else: + endian = '>' # big-endian + + data = np.fromfile(file, endian + 'f') + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + return data + +def writeFlow(filename,uv,v=None): + """ Write optical flow to file. + + If v is None, uv is assumed to contain both u and v channels, + stacked in depth. + Original code by Deqing Sun, adapted from Daniel Scharstein. 
+ """ + nBands = 2 + + if v is None: + assert(uv.ndim == 3) + assert(uv.shape[2] == 2) + u = uv[:,:,0] + v = uv[:,:,1] + else: + u = uv + + assert(u.shape == v.shape) + height,width = u.shape + f = open(filename,'wb') + # write the header + f.write(TAG_CHAR) + np.array(width).astype(np.int32).tofile(f) + np.array(height).astype(np.int32).tofile(f) + # arrange into matrix form + tmp = np.zeros((height, width*nBands)) + tmp[:,np.arange(width)*2] = u + tmp[:,np.arange(width)*2 + 1] = v + tmp.astype(np.float32).tofile(f) + f.close() + + +def readFlowKITTI(filename): + flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR) + flow = flow[:,:,::-1].astype(np.float32) + flow, valid = flow[:, :, :2], flow[:, :, 2] + flow = (flow - 2**15) / 64.0 + return flow, valid + +def readDispKITTI(filename): + disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 + valid = disp > 0.0 + flow = np.stack([-disp, np.zeros_like(disp)], -1) + return flow, valid + + +def writeFlowKITTI(filename, uv): + uv = 64.0 * uv + 2**15 + valid = np.ones([uv.shape[0], uv.shape[1], 1]) + uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) + cv2.imwrite(filename, uv[..., ::-1]) + + +def read_gen(file_name, pil=False): + ext = osp.splitext(file_name)[-1] + if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': + return Image.open(file_name) + elif ext == '.bin' or ext == '.raw': + return np.load(file_name) + elif ext == '.flo': + return readFlow(file_name).astype(np.float32) + elif ext == '.pfm': + flow = readPFM(file_name).astype(np.float32) + if len(flow.shape) == 2: + return flow + else: + return flow[:, :, :-1] + return [] + + +class SparseFlowAugmentor: + def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False): + # spatial augmentation params + self.crop_size = crop_size + self.min_scale = min_scale + self.max_scale = max_scale + self.spatial_aug_prob = 0.8 + self.stretch_prob = 0.8 + self.max_stretch = 0.2 + + # flip augmentation params + self.do_flip = do_flip + self.h_flip_prob = 0.5 + self.v_flip_prob = 0.1 + + # photometric augmentation params + self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14) + self.asymmetric_color_aug_prob = 0.2 + self.eraser_aug_prob = 0.5 + + def color_transform(self, img1, img2): + image_stack = np.concatenate([img1, img2], axis=0) + image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) + img1, img2 = np.split(image_stack, 2, axis=0) + return img1, img2 + + def eraser_transform(self, img1, img2): + ht, wd = img1.shape[:2] + if np.random.rand() < self.eraser_aug_prob: + mean_color = np.mean(img2.reshape(-1, 3), axis=0) + for _ in range(np.random.randint(1, 3)): + x0 = np.random.randint(0, wd) + y0 = np.random.randint(0, ht) + dx = np.random.randint(50, 100) + dy = np.random.randint(50, 100) + img2[y0:y0+dy, x0:x0+dx, :] = mean_color + + return img1, img2 + + def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0): + ht, wd = flow.shape[:2] + coords = np.meshgrid(np.arange(wd), np.arange(ht)) + coords = np.stack(coords, axis=-1) + + coords = coords.reshape(-1, 2).astype(np.float32) + flow = flow.reshape(-1, 2).astype(np.float32) + valid = valid.reshape(-1).astype(np.float32) + + coords0 = coords[valid>=1] + flow0 = flow[valid>=1] + + ht1 = int(round(ht * fy)) + wd1 = int(round(wd * fx)) + + coords1 = coords0 * [fx, fy] + flow1 = flow0 * [fx, fy] + + xx = np.round(coords1[:,0]).astype(np.int32) + yy = np.round(coords1[:,1]).astype(np.int32) + + v = (xx > 0) & (xx < wd1) 
& (yy > 0) & (yy < ht1) + xx = xx[v] + yy = yy[v] + flow1 = flow1[v] + + flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32) + valid_img = np.zeros([ht1, wd1], dtype=np.int32) + + flow_img[yy, xx] = flow1 + valid_img[yy, xx] = 1 + + return flow_img, valid_img + + def spatial_transform(self, img1, img2, flow, valid): + # randomly sample scale + + ht, wd = img1.shape[:2] + min_scale = np.maximum( + (self.crop_size[0] + 1) / float(ht), + (self.crop_size[1] + 1) / float(wd)) + + scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) + scale_x = np.clip(scale, min_scale, None) + scale_y = np.clip(scale, min_scale, None) + + if np.random.rand() < self.spatial_aug_prob: + # rescale the images + img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y) + + if self.do_flip: + if np.random.rand() < 0.5: # h-flip + img1 = img1[:, ::-1] + img2 = img2[:, ::-1] + flow = flow[:, ::-1] * [-1.0, 1.0] + valid = valid[:, ::-1] + + margin_y = 20 + margin_x = 50 + + y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y) + x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x) + + y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0]) + x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1]) + + img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + return img1, img2, flow, valid + + def __call__(self, img1, img2, flow, valid): + img1, img2 = self.color_transform(img1, img2) + img1, img2 = self.eraser_transform(img1, img2) + img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid) + + img1 = np.ascontiguousarray(img1) + img2 = np.ascontiguousarray(img2) + flow = np.ascontiguousarray(flow) + valid = np.ascontiguousarray(valid) + + return img1, img2, flow, valid + +class FlowAugmentor: + def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True): + + # spatial augmentation params + self.crop_size = crop_size + self.min_scale = min_scale + self.max_scale = max_scale + self.spatial_aug_prob = 0.8 + self.stretch_prob = 0.8 + self.max_stretch = 0.2 + + # flip augmentation params + self.do_flip = do_flip + self.h_flip_prob = 0.5 + self.v_flip_prob = 0.1 + + # photometric augmentation params + self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14) + self.asymmetric_color_aug_prob = 0.2 + self.eraser_aug_prob = 0.5 + + def color_transform(self, img1, img2): + """ Photometric augmentation """ + + # asymmetric + if np.random.rand() < self.asymmetric_color_aug_prob: + img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8) + img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8) + + # symmetric + else: + image_stack = np.concatenate([img1, img2], axis=0) + image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) + img1, img2 = np.split(image_stack, 2, axis=0) + + return img1, img2 + + def eraser_transform(self, img1, img2, bounds=[50, 100]): + """ Occlusion augmentation """ + + ht, wd = img1.shape[:2] + if np.random.rand() < self.eraser_aug_prob: + mean_color = np.mean(img2.reshape(-1, 3), axis=0) + for _ in range(np.random.randint(1, 3)): + x0 = 
np.random.randint(0, wd) + y0 = np.random.randint(0, ht) + dx = np.random.randint(bounds[0], bounds[1]) + dy = np.random.randint(bounds[0], bounds[1]) + img2[y0:y0+dy, x0:x0+dx, :] = mean_color + + return img1, img2 + + def spatial_transform(self, img1, img2, flow): + # randomly sample scale + ht, wd = img1.shape[:2] + min_scale = np.maximum( + (self.crop_size[0] + 8) / float(ht), + (self.crop_size[1] + 8) / float(wd)) + + scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) + scale_x = scale + scale_y = scale + if np.random.rand() < self.stretch_prob: + scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) + scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) + + scale_x = np.clip(scale_x, min_scale, None) + scale_y = np.clip(scale_y, min_scale, None) + + if np.random.rand() < self.spatial_aug_prob: + # rescale the images + img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + flow = flow * [scale_x, scale_y] + + if self.do_flip: + if np.random.rand() < self.h_flip_prob: # h-flip + img1 = img1[:, ::-1] + img2 = img2[:, ::-1] + flow = flow[:, ::-1] * [-1.0, 1.0] + + if np.random.rand() < self.v_flip_prob: # v-flip + img1 = img1[::-1, :] + img2 = img2[::-1, :] + flow = flow[::-1, :] * [1.0, -1.0] + + y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0]) + x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1]) + + img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + + return img1, img2, flow + + def __call__(self, img1, img2, flow): + img1, img2 = self.color_transform(img1, img2) + img1, img2 = self.eraser_transform(img1, img2) + img1, img2, flow = self.spatial_transform(img1, img2, flow) + + img1 = np.ascontiguousarray(img1) + img2 = np.ascontiguousarray(img2) + flow = np.ascontiguousarray(flow) + + return img1, img2, flow + + +class SparseFlowAugmentor: + def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False): + # spatial augmentation params + self.crop_size = crop_size + self.min_scale = min_scale + self.max_scale = max_scale + self.spatial_aug_prob = 0.8 + self.stretch_prob = 0.8 + self.max_stretch = 0.2 + + # flip augmentation params + self.do_flip = do_flip + self.h_flip_prob = 0.5 + self.v_flip_prob = 0.1 + + # photometric augmentation params + self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14) + self.asymmetric_color_aug_prob = 0.2 + self.eraser_aug_prob = 0.5 + + def color_transform(self, img1, img2): + image_stack = np.concatenate([img1, img2], axis=0) + image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) + img1, img2 = np.split(image_stack, 2, axis=0) + return img1, img2 + + def eraser_transform(self, img1, img2): + ht, wd = img1.shape[:2] + if np.random.rand() < self.eraser_aug_prob: + mean_color = np.mean(img2.reshape(-1, 3), axis=0) + for _ in range(np.random.randint(1, 3)): + x0 = np.random.randint(0, wd) + y0 = np.random.randint(0, ht) + dx = np.random.randint(50, 100) + dy = np.random.randint(50, 100) + img2[y0:y0+dy, x0:x0+dx, :] = mean_color + + return img1, img2 + + def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0): + ht, wd = flow.shape[:2] + coords = 
np.meshgrid(np.arange(wd), np.arange(ht)) + coords = np.stack(coords, axis=-1) + + coords = coords.reshape(-1, 2).astype(np.float32) + flow = flow.reshape(-1, 2).astype(np.float32) + valid = valid.reshape(-1).astype(np.float32) + + coords0 = coords[valid>=1] + flow0 = flow[valid>=1] + + ht1 = int(round(ht * fy)) + wd1 = int(round(wd * fx)) + + coords1 = coords0 * [fx, fy] + flow1 = flow0 * [fx, fy] + + xx = np.round(coords1[:,0]).astype(np.int32) + yy = np.round(coords1[:,1]).astype(np.int32) + + v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1) + xx = xx[v] + yy = yy[v] + flow1 = flow1[v] + + flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32) + valid_img = np.zeros([ht1, wd1], dtype=np.int32) + + flow_img[yy, xx] = flow1 + valid_img[yy, xx] = 1 + + return flow_img, valid_img + + def spatial_transform(self, img1, img2, flow, valid): + # randomly sample scale + + ht, wd = img1.shape[:2] + min_scale = np.maximum( + (self.crop_size[0] + 1) / float(ht), + (self.crop_size[1] + 1) / float(wd)) + + scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) + scale_x = np.clip(scale, min_scale, None) + scale_y = np.clip(scale, min_scale, None) + + if np.random.rand() < self.spatial_aug_prob: + # rescale the images + img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y) + + if self.do_flip: + if np.random.rand() < 0.5: # h-flip + img1 = img1[:, ::-1] + img2 = img2[:, ::-1] + flow = flow[:, ::-1] * [-1.0, 1.0] + valid = valid[:, ::-1] + + margin_y = 20 + margin_x = 50 + + y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y) + x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x) + + y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0]) + x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1]) + + img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + return img1, img2, flow, valid + + def __call__(self, img1, img2, flow, valid): + img1, img2 = self.color_transform(img1, img2) + img1, img2 = self.eraser_transform(img1, img2) + img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid) + + img1 = np.ascontiguousarray(img1) + img2 = np.ascontiguousarray(img2) + flow = np.ascontiguousarray(flow) + valid = np.ascontiguousarray(valid) + + return img1, img2, flow, valid + + +class FlowDataset(data.Dataset): + def __init__(self, aug_params=None, sparse=False): + self.augmentor = None + self.sparse = sparse + if aug_params is not None: + if sparse: + self.augmentor = SparseFlowAugmentor(**aug_params) + else: + self.augmentor = FlowAugmentor(**aug_params) + + self.is_test = False + self.init_seed = False + self.flow_list = [] + self.image_list = [] + self.extra_info = [] + self.occ_list = None + self.seg_list = None + self.seg_inv_list = None + + def __getitem__(self, index): + + src_image1_name = self.image_list[index][0] + src_image2_name = self.image_list[index][1] + + if self.is_test: + img1 = read_gen(self.image_list[index][0]) + img2 = read_gen(self.image_list[index][1]) + img1 = np.array(img1).astype(np.uint8)[..., :3] + img2 = np.array(img2).astype(np.uint8)[..., :3] + img1 = torch.from_numpy(img1).permute(2, 0, 1).float() + img2 = 
torch.from_numpy(img2).permute(2, 0, 1).float() + return src_image1_name, src_image2_name, img1, img2, self.extra_info[index] + + if not self.init_seed: + worker_info = torch.utils.data.get_worker_info() + if worker_info is not None: + torch.manual_seed(worker_info.id) + np.random.seed(worker_info.id) + random.seed(worker_info.id) + self.init_seed = True + + index = index % len(self.image_list) + valid = None + if self.sparse: + flow, valid = readFlowKITTI(self.flow_list[index]) + else: + flow = read_gen(self.flow_list[index]) + + if self.occ_list is not None: + occ = read_gen(self.occ_list[index]) + occ = np.array(occ).astype(np.uint8) + occ = torch.from_numpy(occ // 255).bool() + + if self.seg_list is not None: + f_in = np.array(read_gen(self.seg_list[index])) + seg_r = f_in[:, :, 0].astype('int32') + seg_g = f_in[:, :, 1].astype('int32') + seg_b = f_in[:, :, 2].astype('int32') + seg_map = (seg_r * 256 + seg_g) * 256 + seg_b + seg_map = torch.from_numpy(seg_map) + + if self.seg_inv_list is not None: + seg_inv = read_gen(self.seg_inv_list[index]) + seg_inv = np.array(seg_inv).astype(np.uint8) + seg_inv = torch.from_numpy(seg_inv // 255).bool() + + img1 = read_gen(self.image_list[index][0]) + img2 = read_gen(self.image_list[index][1]) + + flow = np.array(flow).astype(np.float32) + img1 = np.array(img1).astype(np.uint8) + img2 = np.array(img2).astype(np.uint8) + + # grayscale images + if len(img1.shape) == 2: + img1 = np.tile(img1[...,None], (1, 1, 3)) + img2 = np.tile(img2[...,None], (1, 1, 3)) + else: + img1 = img1[..., :3] + img2 = img2[..., :3] + + if self.augmentor is not None: + if self.sparse: + img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid) + else: + img1, img2, flow = self.augmentor(img1, img2, flow) + + img1 = torch.from_numpy(img1).permute(2, 0, 1).float() + img2 = torch.from_numpy(img2).permute(2, 0, 1).float() + flow = torch.from_numpy(flow).permute(2, 0, 1).float() + + if valid is not None: + valid = torch.from_numpy(valid) + else: + valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000) + + if self.occ_list is not None: + return src_image1_name, src_image2_name, img1, img2, flow, valid.float(), occ, self.occ_list[index] + elif self.seg_list is not None and self.seg_inv_list is not None: + return src_image1_name, src_image2_name, img1, img2, flow, valid.float(), seg_map, seg_inv + else: + return src_image1_name, src_image2_name, img1, img2, flow, valid.float()#, self.extra_info[index] + + def __rmul__(self, v): + self.flow_list = v * self.flow_list + self.image_list = v * self.image_list + return self + + def __len__(self): + return len(self.image_list) + + +class MpiSintel(FlowDataset): + def __init__(self, aug_params=None, split='training', root='/home/lq/Sintel/', dstype='clean', + occlusion=False, segmentation=False): + super(MpiSintel, self).__init__(aug_params) + flow_root = osp.join(root, split, 'flow') # true + image_root = osp.join(root, split, dstype) # true + # occ_root = osp.join(root, split, 'occlusions') + # occ_root = osp.join(root, split, 'occ_plus_out') + # occ_root = osp.join(root, split, 'in_frame_occ') + occ_root = osp.join(root, split, 'out_of_frame') # false + + seg_root = osp.join(root, split, 'segmentation') # false + seg_inv_root = osp.join(root, split, 'segmentation_invalid') # false + self.segmentation = segmentation + self.occlusion = occlusion + if self.occlusion: + self.occ_list = [] + if self.segmentation: + self.seg_list = [] + self.seg_inv_list = [] + + # split == 'training' + if split == 'test': + self.is_test = True + + 
for scene in os.listdir(image_root): + image_list = sorted(glob(osp.join(image_root, scene, '*.png'))) + for i in range(len(image_list)-1): + self.image_list += [ [image_list[i], image_list[i+1]] ] + self.extra_info += [ (scene, i) ] # scene and frame_id + + if split != 'test': + self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo'))) + if self.occlusion: + self.occ_list += sorted(glob(osp.join(occ_root, scene, '*.png'))) + if self.segmentation: + self.seg_list += sorted(glob(osp.join(seg_root, scene, '*.png'))) + self.seg_inv_list += sorted(glob(osp.join(seg_inv_root, scene, '*.png'))) + +class InputPadder: + """ Pads images such that dimensions are divisible by 8 """ + def __init__(self, dims, mode='sintel'): + self.ht, self.wd = dims[-2:] + pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 + pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 + if mode == 'sintel': + self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] + else: + self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] + + def pad(self, *inputs): + return [F.pad(x, self._pad, mode='replicate') for x in inputs] + + def unpad(self,x): + ht, wd = x.shape[-2:] + c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] + return x[..., c[0]:c[1], c[2]:c[3]] + + +def preprocess(src_path, save_path): + if not os.path.exists(src_path): + raise Exception("src_path not exists !") + + if not os.path.exists(save_path): + os.mkdir(save_path) + + for dstype in ['clean']: + sub_save_path = os.path.join(save_path,dstype) + if not os.path.exists(sub_save_path): + os.mkdir(sub_save_path) + # aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True} + # val_dataset = datasets.MpiSintel(aug_params, split='training', dstype=dstype, root=src_path) + val_dataset = MpiSintel(split='training', dstype=dstype, root=src_path) + for val_id in range(len(val_dataset)): + + img1_name, img2_name, image1, image2, flow_gt, valid_gt = val_dataset[val_id] + + image1 = image1[None] + image2 = image2[None] + padder = InputPadder(image1.shape) + image1, image2 = padder.pad(image1, image2) + + # Example: img1_name = "/home/lq/Sintel/training/clean/alley_1/frame_0001.png" + scene = img1_name.split('/')[-2] + scene_dir = os.path.join(sub_save_path,scene) + file_name_1 = os.path.join(scene_dir, img1_name.split('/')[-1].split('.')[0] + ".bin") + file_name_2 = os.path.join(scene_dir, img2_name.split('/')[-1].split('.')[0] + ".bin") + + # Example: flow_gt_name = /home/lq/Sintel_bin/training/clean/alley_1/frame_0001_0002.bin + st_name = img1_name.split('/')[-1].split('.')[0].split('_')[1] + ed_name = img2_name.split('/')[-1].split('.')[0].split('_')[1] + flow_gt_name = img1_name.split('/')[-1].split('.')[0] + '_' + ed_name + valid_gt_name = flow_gt_name + "_valid_gt" + flow_gt_name = os.path.join(scene_dir, flow_gt_name + '.bin') + valid_gt_name = os.path.join(scene_dir, valid_gt_name + '.bin') + + + if not os.path.exists(scene_dir): + os.mkdir(scene_dir) + if not os.path.isfile(file_name_1): + image1 = image1.numpy() + image1.tofile(file_name_1) + if not os.path.isfile(file_name_2): + image2 = image2.numpy() + image2.tofile(file_name_2) + if not os.path.isfile(flow_gt_name): + flow_gt = flow_gt.numpy() + flow_gt.tofile(flow_gt_name) + # if not os.path.isfile(valid_gt_name): + # valid_gt = valid_gt.numpy() + # valid_gt.tofile(valid_gt_name) + + # img.tofile(os.path.join(save_path, file.split('.')[0] + ".bin")) + if (int(val_id * 1.0 / len(val_dataset) * 100) % 10 == 0): + print("Complete : 
{}%".format(int(val_id * 1.0 / len(val_dataset) * 100))) + # print("image1.shape = ",image1.shape) + # print("flow_gt.shape = ",flow_gt.shape) + + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument('--src_path', type=str, default="/home/lq/Sintel/", help='dataset source path') + parser.add_argument('--save_path', type=str, default="/home/lq/Sintel_bin/", help='dataset bin save path') + + args = parser.parse_args() + + + torch.manual_seed(1234) + np.random.seed(1234) + + preprocess(src_path=args.src_path, save_path=args.save_path) + diff --git a/ACL_PyTorch/contrib/cv/classfication/GMA/GMA_pth2onnx.py b/ACL_PyTorch/contrib/cv/classfication/GMA/GMA_pth2onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..0566ff959f649bc59f0242170f38c2b80e18c991 --- /dev/null +++ b/ACL_PyTorch/contrib/cv/classfication/GMA/GMA_pth2onnx.py @@ -0,0 +1,89 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +sys.path.append('GMA/core') +from network import RAFTGMA +import argparse +import torch +import onnx +import onnxsim +from collections import OrderedDict + + +def proc_nodes_module(checkpoint): + new_state_dict = OrderedDict() + for k, v in checkpoint.items(): + if "module." 
in k: + name = k.replace("module.", "") + else: + name = k + new_state_dict[name] = v + return new_state_dict + +def pth2onnx(input_file, output_file, args, opset_version=11): + + checkpoint = torch.load(input_file, map_location='cpu') + checkpoint = proc_nodes_module(checkpoint) + + model = RAFTGMA(args) + model.load_state_dict(checkpoint) + model.eval() + + dummy_input_0 = torch.randn(1, 3, args.image_size[0], args.image_size[1]) # 0 dimension --> bs + dummy_input_1 = torch.randn(1, 3, args.image_size[0], args.image_size[1]) # 0 dimension --> bs + + dummy_input = (dummy_input_0, dummy_input_1) + input_names = ['image1', 'image2'] + output_names = ['flow_pred'] + torch.onnx.export(model, dummy_input, output_file, input_names = input_names, output_names = output_names, opset_version=opset_version, verbose=True) + + onnx_model = onnx.load(output_file) + input_shapes = { + 'image1': [1,3,args.image_size[0],args.image_size[1]], + 'image2': [1,3,args.image_size[0],args.image_size[1]] + } + model_simp, check = onnxsim.simplify(onnx_model, input_shapes=input_shapes) + assert check, "Simplified ONNX model could not be validated" + onnx.save(model_simp, output_file) + print("onnx simpify ok") + + + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument('--image_size', type=int, nargs='+', default=[440, 1024]) + parser.add_argument('--num_heads', default=1, type=int, + help='number of heads in attention and aggregation') + parser.add_argument('--upsample-learn', action='store_true', default=False, + help='If True, use learned upsampling, otherwise, use bilinear upsampling.') + + parser.add_argument('--position_only', default=False, action='store_true', + help='only use position-wise attention') + parser.add_argument('--position_and_content', default=False, action='store_true', + help='use position and content-wise attention') + parser.add_argument('--output_file', type=str, default="output/120000_gma-sintel_ddp_bs1.onnx", help='.onnx output file') + parser.add_argument('--input_file', type=str, default="./120000_gma-sintel_ddp.pth", help='.pth input file') + parser.add_argument('--mixed_precision', default=False, action='store_true', + help='use mixed precision') + + + + args = parser.parse_args() + output_file = args.output_file + input_file = args.input_file + + pth2onnx(input_file=input_file, output_file=output_file, args=args, opset_version=11) \ No newline at end of file diff --git a/ACL_PyTorch/contrib/cv/classfication/GMA/LICENSE b/ACL_PyTorch/contrib/cv/classfication/GMA/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..56ee3c8c4cc2b4b32e0975d17258f9ba515fdbcc --- /dev/null +++ b/ACL_PyTorch/contrib/cv/classfication/GMA/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
\ No newline at end of file diff --git a/ACL_PyTorch/contrib/cv/classfication/GMA/README.md b/ACL_PyTorch/contrib/cv/classfication/GMA/README.md new file mode 100644 index 0000000000000000000000000000000000000000..28077dfdefdef9d2f773b2cf6aa9f48522816ddd --- /dev/null +++ b/ACL_PyTorch/contrib/cv/classfication/GMA/README.md @@ -0,0 +1,55 @@ +# GMA PyTorch Offline Inference Guide + +## 1 Environment Setup + +1. Install the required dependencies. The test environment may already contain different versions of some of these libraries, so installing with this command is not recommended when testing manually. + +``` +pip3.7 install -r requirements.txt +``` + +2. Obtain and patch the open-source model code + +``` +git clone https://github.com/zacjiang/GMA.git +cd GMA +git reset 2f1fd29468a86a354d44dd25d107930b3f175043 --hard +git apply ../GMA.patch +cd .. +``` + +3. Obtain the weight file + +The weight file is 120000_gma-sintel_ddp.pth (test/pth2om.sh converts it to 120000_gma-sintel_ddp_bs1.onnx and then to the om model). + +[GMA pretrained pth weights (Baidu Netdisk download, extraction code: 0opv)](https://pan.baidu.com/s/1BnljcJXloRVYHBaeBHECww) + +4. Dataset + +[Sintel dataset](http://sintel.is.tue.mpg.de/) + +Download MPI-Sintel-complete.zip from the official site and extract it into a newly created /opt/npu/Sintel directory. + +5. [Obtain the msame tool](https://gitee.com/ascend/tools/tree/master/msame) + +Install msame by following the tool's installation instructions. + + +## 2 Offline Inference + +Run the following on the Ascend 310; during execution, use npu-smi info to check the device status and make sure the device is idle. + +``` +bash test/pth2om.sh +bash test/eval_acc_perf.sh --datasets_path=/opt/npu/Sintel +``` + +Evaluation results + +| Model | pth accuracy | 310 accuracy | Benchmark performance | 310 performance | | -------- | ------- | ------- | -------- | ---------------------- | | GMA bs1 | AEPE: 0.471264 | AEPE: 0.506081 | 0.1153 fps | 0.1145 fps | + +Notes: +1. ONNX does not support the grid_sample operator; it is replaced with bilinear_grid_sample, the equivalent code from the tests of mmcv's custom grid_sample op (see the reference sketch at the end of this document). +2. Because of the large input resolution and the resulting memory usage, the model does not support multi-batch inference for now. \ No newline at end of file diff --git a/ACL_PyTorch/contrib/cv/classfication/GMA/env.sh b/ACL_PyTorch/contrib/cv/classfication/GMA/env.sh new file mode 100644 index 0000000000000000000000000000000000000000..686e9e10272f220567555951d02f90bc38aa4752 --- /dev/null +++ b/ACL_PyTorch/contrib/cv/classfication/GMA/env.sh @@ -0,0 +1,16 @@ +#!
/bin/bash + +# CANN installation directory +export install_path=/home/lq/Ascend616/ascend-toolkit/latest +export PATH=/usr/local/python3.7.5/bin:${install_path}/atc/ccec_compiler/bin:${install_path}/atc/bin:$PATH +export PYTHONPATH=${install_path}/atc/python/site-packages:$PYTHONPATH +export LD_LIBRARY_PATH=${install_path}/atc/lib64:${install_path}/acllib/lib64:$LD_LIBRARY_PATH +export ASCEND_OPP_PATH=${install_path}/opp +export ASCEND_AICPU_PATH=${install_path} +export TOOLCHAIN_HOME=${install_path}/toolkit +# Print atc logs to the screen +# export ASCEND_SLOG_PRINT_TO_STDOUT=1 +# Set the log level +#export ASCEND_GLOBAL_LOG_LEVEL=0 #debug 0 --> info 1 --> warning 2 --> error 3 +# Enable GE graph dump +#export DUMP_GE_GRAPH=2 diff --git a/ACL_PyTorch/contrib/cv/classfication/GMA/get_info.py b/ACL_PyTorch/contrib/cv/classfication/GMA/get_info.py new file mode 100644 index 0000000000000000000000000000000000000000..d9fa96932b0256de2d61f4bc92b0213b34cfde43 --- /dev/null +++ b/ACL_PyTorch/contrib/cv/classfication/GMA/get_info.py @@ -0,0 +1,52 @@ +import os +import os.path as osp +import sys +from glob import glob + +def get_bin_info_valid_gt(file_path, info_name): + bin_images = glob(osp.join(file_path, "*.bin")) + bin_images.sort() + # Example : /home/lq/Sintel_bin/clean/alley_1/frame_0001.bin /home/lq/Sintel_bin/clean/alley_1/frame_0002.bin /home/lq/Sintel_bin/clean/alley_1/frame_0001_0002.bin clean/alley_1/frame_0001_0002 /home/lq/Sintel_bin/clean/alley_1/frame_0001_0002_valid_gt.bin + + with open(info_name, 'a+') as f: + idx = 3 + while(idx < len(bin_images)): + tmp = bin_images[idx-1].split('/')[4:] + tmp = "/".join(tmp) + output_path_sub = tmp[:-4] + content = " ".join([bin_images[idx-3], bin_images[idx], bin_images[idx-2], output_path_sub, bin_images[idx-1]]) + f.write(content) + f.write('\n') + idx += 3 + +def get_bin_info(file_path, info_name): + bin_images = glob(osp.join(file_path, "*.bin")) + bin_images.sort() + # Example : /home/lq/Sintel_bin/clean/alley_1/frame_0001.bin /home/lq/Sintel_bin/clean/alley_1/frame_0002.bin /home/lq/Sintel_bin/clean/alley_1/frame_0001_0002.bin clean/alley_1/frame_0001_0002 /home/lq/Sintel_bin/clean/alley_1/frame_0001_0002_valid_gt.bin + + with open(info_name, 'a+') as f: + idx = 2 + while(idx < len(bin_images)): + tmp = bin_images[idx-1].split('/')[-3:] + tmp = "/".join(tmp) + output_path_sub = tmp[:-4] + content = " ".join([bin_images[idx-2], bin_images[idx], bin_images[idx-1], output_path_sub]) + f.write(content) + f.write('\n') + idx += 2 + + +if __name__ == '__main__': + + file_type = sys.argv[1] + file_path = sys.argv[2] # example : "/home/lq/Sintel_bin_no_aug" + info_name = sys.argv[3] # example : "./get_info_no_aug.txt" + + if file_type == 'bin': + dstype = "clean" + scene_list = os.listdir(osp.join(file_path, dstype)) + scene_list.sort() + for scene in scene_list: + get_bin_info(osp.join(file_path,dstype,scene), info_name) + else: + raise Exception("Only the 'bin' file type is supported.") \ No newline at end of file diff --git a/ACL_PyTorch/contrib/cv/classfication/GMA/inference.sh b/ACL_PyTorch/contrib/cv/classfication/GMA/inference.sh new file mode 100644 index 0000000000000000000000000000000000000000..77f8288e7e56445fd330a69aa26a473d854314dd --- /dev/null +++ b/ACL_PyTorch/contrib/cv/classfication/GMA/inference.sh @@ -0,0 +1,37 @@ +#!/bin/bash +source ./env.sh +msame_path=/home/lq/tools/msame/out/msame + +for para in $* +do + if [[ $para == --om_path* ]];then + om_path=`echo ${para#*=}` + elif [[ $para == --output_path* ]];then + output_path=`echo ${para#*=}` + elif [[ $para == --info_name* ]];then
+ info_name=`echo ${para#*=}` + fi +done + +cat ${info_name} | while read LINE +do + # LINE : img1 img2 flow_gt_path output_path_sub + # Example : /home/lq/Sintel_bin/clean/alley_1/frame_0001.bin /home/lq/Sintel_bin/clean/alley_1/frame_0002.bin /home/lq/Sintel_bin/clean/alley_1/frame_0001_0002.bin clean/alley_1/frame_0001_0002 + # item_output_dir = om_inference_output/clean/alley_1/frame_0001_0002 + + IFS=' ' + inarray=(${LINE}) + img1=${inarray[0]} + img2=${inarray[1]} + flow_gt_path=${inarray[2]} + + # echo ${flow_gt_path%.*} + output_path_sub=${inarray[3]} + + input_path=${img1},${img2} + item_output_dir=${output_path}/${output_path_sub} + + mkdir -p ${item_output_dir} + + ${msame_path} --model ${om_path} --input ${input_path} --output ${item_output_dir} --outfmt BIN --loop 1 +done \ No newline at end of file diff --git a/ACL_PyTorch/contrib/cv/classfication/GMA/requirements.txt b/ACL_PyTorch/contrib/cv/classfication/GMA/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..d3fa157b4f6f7f5a9ad704562eb33e2fb8ee03ac --- /dev/null +++ b/ACL_PyTorch/contrib/cv/classfication/GMA/requirements.txt @@ -0,0 +1,7 @@ +torch == 1.8.0 +torchvision == 0.9.0 +onnx == 1.9.0 +numpy == 1.21.2 +Pillow == 8.3.1 +opencv-python == 4.5.3.56 +onnx-simplifier == 0.3.10 \ No newline at end of file diff --git a/ACL_PyTorch/contrib/cv/classfication/GMA/test/eval_acc_perf.sh b/ACL_PyTorch/contrib/cv/classfication/GMA/test/eval_acc_perf.sh new file mode 100644 index 0000000000000000000000000000000000000000..e11e8ab9d9891d1d0247e4b1f64553335720d203 --- /dev/null +++ b/ACL_PyTorch/contrib/cv/classfication/GMA/test/eval_acc_perf.sh @@ -0,0 +1,58 @@ +#!/bin/bash + +datasets_path="/opt/npu/Sintel/" +msame_path=/home/lq/tools/msame/out/msame + +for para in $* +do + if [[ $para == --datasets_path* ]]; then + datasets_path=`echo ${para#*=}` + fi +done + +arch=`uname -m` + +# Data preprocessing +rm -rf ./prep_data +python3.7 GMA_preprocess.py --src_path=${datasets_path} --save_path=./prep_data +if [ $? != 0 ]; then + echo "fail!" + exit -1 +fi + +# Generate the info file used for inference +rm -rf ./GMA_prep_bin.info +python3.7 get_info.py bin ./prep_data/ ./GMA_prep_bin.info +if [ $? != 0 ]; then + echo "fail!" + exit -1 +fi + +source env.sh + +# Inference +rm -rf ./result/dumpOutput_device0/ + +mkdir -p ./result/dumpOutput_device0/ +./inference.sh --om_path=./120000_gma-sintel_ddp_bs1.om --info_name=./GMA_prep_bin.info --output_path=./result/dumpOutput_device0/ +if [ $? != 0 ]; then + echo "fail!" + exit -1 +fi + +echo "====accuracy data====" +python3.7 GMA_postprocess.py --res ./result/dumpOutput_device0/ --bin_src ./prep_data/bin --info_name ./GMA_prep_bin.info +if [ $? != 0 ]; then + echo "fail!" + exit -1 +fi + +echo "====performance data====" +mkdir -p performance_out +${msame_path} --model ./120000_gma-sintel_ddp_bs1.om --output performance_out/ --outfmt BIN --loop 20 >& result/performance_out.log +python3.7 test/parse.py result/performance_out.log +if [ $? != 0 ]; then + echo "fail!"
+ exit -1 +fi +echo "success" diff --git a/ACL_PyTorch/contrib/cv/classfication/GMA/test/parse.py b/ACL_PyTorch/contrib/cv/classfication/GMA/test/parse.py new file mode 100644 index 0000000000000000000000000000000000000000..392da13e5d7755de7056b8decf4c669dc50d7a08 --- /dev/null +++ b/ACL_PyTorch/contrib/cv/classfication/GMA/test/parse.py @@ -0,0 +1,19 @@ +import sys +import json +import re + +if __name__ == "__main__": + # if sys.argv[1].endswith('.txt'): + # result_txt = sys.argv[1] + # with open(result_txt, 'r') as f: + + result_txt = sys.argv[1] + + with open(result_txt, 'r') as f: + content = f.readlines() + + txt_data_list = [float( ''.join( re.findall(r'time:(.*)ms',item) ).strip() ) for item in content if "Inference average time" in item] # ms + + second_per_sample = sum(txt_data_list) / len(txt_data_list) / 1000 + fps = 1 / second_per_sample * 4 # one Atlas 300 card carries 4 Ascend 310 devices, so report 4x the single-device throughput + print(f"310 bs1 fps:{fps}") \ No newline at end of file diff --git a/ACL_PyTorch/contrib/cv/classfication/GMA/test/perf_g.sh b/ACL_PyTorch/contrib/cv/classfication/GMA/test/perf_g.sh new file mode 100644 index 0000000000000000000000000000000000000000..1abea936d57ba101b5b530ee5355a237aafe1c6a --- /dev/null +++ b/ACL_PyTorch/contrib/cv/classfication/GMA/test/perf_g.sh @@ -0,0 +1,2 @@ +#!/bin/bash +python3.7 GMA_postprocess.py --test_perf --onnx_path 120000_gma-sintel_ddp_bs1.onnx \ No newline at end of file diff --git a/ACL_PyTorch/contrib/cv/classfication/GMA/test/pth2om.sh b/ACL_PyTorch/contrib/cv/classfication/GMA/test/pth2om.sh new file mode 100644 index 0000000000000000000000000000000000000000..386a197bac9cab7f773971af38a038d98b14eedb --- /dev/null +++ b/ACL_PyTorch/contrib/cv/classfication/GMA/test/pth2om.sh @@ -0,0 +1,17 @@ +input_pth_file=./120000_gma-sintel_ddp.pth +onnx_file=./120000_gma-sintel_ddp_bs1.onnx +img_height=440 +img_width=1024 +batch_size=1 + + +rm -rf *.onnx +python3.7 GMA_pth2onnx.py --input_file=${input_pth_file} --output_file=${onnx_file} --image_size ${img_height} ${img_width} >& GMA_pth2onnx.log +rm -rf *.om +source env.sh +atc --framework=5 --model=${onnx_file} --output=${onnx_file%.onnx} --input_format=NCHW --input_shape="image1:${batch_size},3,${img_height},${img_width};image2:${batch_size},3,${img_height},${img_width}" --log=info --soc_version=Ascend310 --out_nodes='Reshape_17519:0' &> GMA_onnx2om.log +if [ -f ${onnx_file%.onnx}.om ]; then + echo "Success converting pth to om." +else + echo "Fail!" +fi \ No newline at end of file
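For reference, note 1 of the README above replaces grid_sample, which cannot be exported with the torch 1.8.0 / onnx 1.9.0 toolchain, with bilinear_grid_sample, the pure-PyTorch equivalent used in the tests of mmcv's custom grid_sample op. Below is a minimal sketch of such a replacement; it assumes bilinear interpolation with zero padding, mirroring F.grid_sample(..., mode='bilinear', padding_mode='zeros'), and the function and variable names are illustrative rather than copied from the patch.

```
import torch
import torch.nn.functional as F


def bilinear_grid_sample(im, grid, align_corners=False):
    # im:   (N, C, H, W) feature map
    # grid: (N, Hg, Wg, 2) sampling grid with coordinates normalized to [-1, 1]
    n, c, h, w = im.shape
    gn, gh, gw, _ = grid.shape

    x = grid[:, :, :, 0]
    y = grid[:, :, :, 1]

    # Map normalized coordinates to absolute pixel coordinates
    if align_corners:
        x = ((x + 1) / 2) * (w - 1)
        y = ((y + 1) / 2) * (h - 1)
    else:
        x = ((x + 1) * w - 1) / 2
        y = ((y + 1) * h - 1) / 2

    x = x.view(n, -1)
    y = y.view(n, -1)

    x0 = torch.floor(x).long()
    y0 = torch.floor(y).long()
    x1 = x0 + 1
    y1 = y0 + 1

    # Bilinear weights of the four neighbouring pixels
    wa = ((x1 - x) * (y1 - y)).unsqueeze(1)
    wb = ((x1 - x) * (y - y0)).unsqueeze(1)
    wc = ((x - x0) * (y1 - y)).unsqueeze(1)
    wd = ((x - x0) * (y - y0)).unsqueeze(1)

    # Zero-pad the image so that out-of-range neighbours read zeros
    im_padded = F.pad(im, pad=[1, 1, 1, 1], mode='constant', value=0)
    padded_h, padded_w = h + 2, w + 2
    x0, x1, y0, y1 = x0 + 1, x1 + 1, y0 + 1, y1 + 1

    x0 = torch.clamp(x0, 0, padded_w - 1)
    x1 = torch.clamp(x1, 0, padded_w - 1)
    y0 = torch.clamp(y0, 0, padded_h - 1)
    y1 = torch.clamp(y1, 0, padded_h - 1)

    im_padded = im_padded.view(n, c, -1)

    # Flattened indices of the four neighbours, broadcast over channels
    x0_y0 = (x0 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
    x0_y1 = (x0 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)
    x1_y0 = (x1 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
    x1_y1 = (x1 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)

    # Gather the four neighbour values and blend them with the bilinear weights
    Ia = torch.gather(im_padded, 2, x0_y0)
    Ib = torch.gather(im_padded, 2, x0_y1)
    Ic = torch.gather(im_padded, 2, x1_y0)
    Id = torch.gather(im_padded, 2, x1_y1)

    return (Ia * wa + Ib * wb + Ic * wc + Id * wd).reshape(n, c, gh, gw)
```

Because this version is built only from padding, gather, and elementwise arithmetic, the exported graph avoids the GridSample operator that this torch/onnx export path cannot emit.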
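The accuracy figures in the README table are the average end-point error (AEPE) on Sintel. As a quick reference (illustrative only; the actual evaluation lives in GMA_postprocess.py, which is not shown in this excerpt), the metric can be computed as:

```
import numpy as np


def aepe(flow_pred, flow_gt):
    # flow_pred, flow_gt: arrays of shape (2, H, W) holding the (u, v) flow field.
    # End-point error is the Euclidean distance between predicted and
    # ground-truth flow vectors, averaged over all pixels.
    epe = np.sqrt(((flow_pred - flow_gt) ** 2).sum(axis=0))
    return float(epe.mean())
```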