| | import torch |
| | from torch import nn |
| | import sys |
| | sys.path.append('rscd') |
| | from utils.build import build_from_cfg |
| |
|
| | class myModel(nn.Module): |
| | def __init__(self, cfg): |
| | super(myModel, self).__init__() |
| | self.backbone = build_from_cfg(cfg.backbone) |
| | self.decoderhead = build_from_cfg(cfg.decoderhead) |
| | |
| | def forward(self, x1, x2, gtmask=None): |
| | backbone_outputs = self.backbone(x1, x2) |
| | if gtmask == None: |
| | x_list = self.decoderhead(backbone_outputs) |
| | else: |
| | x_list = self.decoderhead(backbone_outputs, gtmask) |
| | return x_list |
| |
|
| | """ |
| | 对于不满足该范式的模型可在backbone部分进行定义, 并在此处导入 |
| | """ |
| |
|
| | |
| | def build_model(cfg): |
| | c = myModel(cfg) |
| | return c |
| |
|
| |
|
| | if __name__ == "__main__": |
| | x1 = torch.randn(4, 3, 512, 512) |
| | x2 = torch.randn(4, 3, 512, 512) |
| | target = torch.randint(low=0,high=2,size=[4, 512, 512]) |
| | file_path = r"E:\zjuse\2308CD\rschangedetection\configs\SARASNet.py" |
| |
|
| | from utils.config import Config |
| | from rscd.losses import build_loss |
| |
|
| | cfg = Config.fromfile(file_path) |
| | net = build_model(cfg.model_config) |
| | res = net(x1, x2) |
| | print(res.shape) |
| | loss = build_loss(cfg.loss_config) |
| |
|
| | compute = loss(res,target) |
| | print(compute) |