import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from sklearn.metrics import f1_score
from torch_geometric.nn import GeneralConv
from torch_geometric.data import Data
from torch_geometric.utils import normalize_edge_index, degree


class FeatureAlign(nn.Module):
    """Project task, query, and LLM features into a shared node-feature space."""

    def __init__(self, query_feature_dim, llm_feature_dim, common_dim):
        super(FeatureAlign, self).__init__()
        self.query_transform = nn.Linear(query_feature_dim, common_dim)
        self.llm_transform = nn.Linear(llm_feature_dim, common_dim * 2)
        self.task_transform = nn.Linear(llm_feature_dim, common_dim)

    def forward(self, task_id, query_features, llm_features):
        # Query nodes: concatenate the task embedding with the query embedding.
        aligned_task_features = self.task_transform(task_id)
        aligned_query_features = self.query_transform(query_features)
        aligned_two_features = torch.cat([aligned_task_features, aligned_query_features], dim=1)
        # LLM nodes: a single projection to the same width (2 * common_dim).
        aligned_llm_features = self.llm_transform(llm_features)
        # Stack query nodes first, then LLM nodes, into one node-feature matrix.
        aligned_features = torch.cat([aligned_two_features, aligned_llm_features], dim=0)
        return aligned_features


class EncoderDecoderNet(torch.nn.Module):
    """Two-layer GNN encoder over the query-LLM graph with an inner-product edge decoder."""

    def __init__(self, query_feature_dim, llm_feature_dim, hidden_features, in_edges):
        super(EncoderDecoderNet, self).__init__()
        self.in_edges = in_edges
        self.model_align = FeatureAlign(query_feature_dim, llm_feature_dim, hidden_features)
        self.encoder_conv_1 = GeneralConv(in_channels=hidden_features * 2,
                                          out_channels=hidden_features * 2,
                                          in_edge_channels=in_edges)
        self.encoder_conv_2 = GeneralConv(in_channels=hidden_features * 2,
                                          out_channels=hidden_features * 2,
                                          in_edge_channels=in_edges)
        self.edge_mlp = nn.Linear(in_edges, in_edges)
        self.bn1 = nn.BatchNorm1d(hidden_features * 2)
        self.bn2 = nn.BatchNorm1d(hidden_features * 2)

    def forward(self, task_id, query_features, llm_features, edge_index,
                edge_mask=None, edge_can_see=None, edge_weight=None):
        if edge_mask is not None:
            # Message passing only uses the visible (train/validation) edges;
            # the masked edges are the ones whose scores we predict.
            edge_index_mask = edge_index[:, edge_can_see]
            edge_index_predict = edge_index[:, edge_mask]
        if edge_weight is not None:
            edge_weight_mask = edge_weight[edge_can_see]
            edge_weight_mask = F.leaky_relu(self.edge_mlp(edge_weight_mask.reshape(-1, self.in_edges)))
            edge_weight_mask = edge_weight_mask.reshape(-1, self.in_edges)
        x_ini = self.model_align(task_id, query_features, llm_features)
        x = F.leaky_relu(self.bn1(self.encoder_conv_1(x_ini, edge_index_mask, edge_attr=edge_weight_mask)))
        x = self.bn2(self.encoder_conv_2(x, edge_index_mask, edge_attr=edge_weight_mask))
        # x[edge_index_predict[1]] = x[edge_index_predict[1]] + x_ini[edge_index_predict[1]]
        # Decode each masked edge from the source (query) input features and the
        # target (LLM) encoded features.
        edge_predict = torch.sigmoid(
            (x_ini[edge_index_predict[0]] * x[edge_index_predict[1]]).mean(dim=-1))
        return edge_predict


class form_data:
    """Assemble raw features, edges, and masks into a PyG `Data` object."""

    def __init__(self, device):
        self.device = device

    def formulation(self, task_id, query_feature, llm_feature, org_node, des_node,
                    edge_feature, label, edge_mask, combined_edge,
                    train_mask, valide_mask, test_mask):
        query_features = torch.tensor(query_feature, dtype=torch.float).to(self.device)
        llm_features = torch.tensor(llm_feature, dtype=torch.float).to(self.device)
        task_id = torch.tensor(task_id, dtype=torch.float).to(self.device)
        query_indices = list(range(len(query_features)))
        llm_indices = [i + len(query_indices) for i in range(len(llm_features))]
        # Shift LLM node ids so they come after the query nodes
        # (org_node[-1] is assumed to be the largest query index).
        des_node = [(i + 1 + org_node[-1]) for i in des_node]
        edge_index = torch.tensor([org_node, des_node], dtype=torch.long).to(self.device)
        edge_weight = torch.tensor(edge_feature, dtype=torch.float).reshape(-1, 1).to(self.device)
        combined_edge = torch.tensor(combined_edge, dtype=torch.float).reshape(-1, 2).to(self.device)
        # Final edge features: observed performance plus the two extra edge channels.
        combined_edge = torch.cat((edge_weight, combined_edge), dim=-1)
        data = Data(task_id=task_id, query_features=query_features, llm_features=llm_features,
                    edge_index=edge_index, edge_attr=edge_weight,
                    query_indices=query_indices, llm_indices=llm_indices,
                    label=torch.tensor(label, dtype=torch.float).to(self.device),
                    edge_mask=edge_mask, combined_edge=combined_edge,
                    train_mask=train_mask, valide_mask=valide_mask, test_mask=test_mask,
                    org_combine=combined_edge)
        return data


class GNN_prediction:
    """Wrapper that owns the GNN model and runs inference for LLM selection."""

    def __init__(self, query_feature_dim, llm_feature_dim, hidden_features_size,
                 in_edges_size, config, device):
        self.model = EncoderDecoderNet(query_feature_dim=query_feature_dim,
                                       llm_feature_dim=llm_feature_dim,
                                       hidden_features=hidden_features_size,
                                       in_edges=in_edges_size).to(device)
        self.config = config

    def test(self, data, model_path, llm_names):
        state_dict = torch.load(model_path, map_location='cpu')
        self.model.load_state_dict(state_dict)
        self.model.eval()
        mask = data.edge_mask.clone().to(torch.bool)
        # At test time the model may only see training and validation edges.
        edge_can_see = torch.logical_or(data.train_mask, data.valide_mask)
        with torch.no_grad():
            edge_predict = self.model(task_id=data.task_id,
                                      query_features=data.query_features,
                                      llm_features=data.llm_features,
                                      edge_index=data.edge_index,
                                      edge_mask=mask,
                                      edge_can_see=edge_can_see,
                                      edge_weight=data.combined_edge)
        # One row per test query, one column per candidate LLM.
        edge_predict = edge_predict.reshape(-1, self.config['llm_num'])
        # Observed scores of the masked edges; kept for inspection, not used below.
        value_test = data.edge_attr[mask].reshape(-1, self.config['llm_num'])
        # Pick an LLM by sampling from the softmax over the predicted scores
        # (.item() assumes a single test query per call).
        probs = torch.softmax(edge_predict, dim=1)
        max_idx = torch.multinomial(probs, 1).item()
        best_llm = llm_names[max_idx]  # map the sampled index back to its LLM name / API
        print(best_llm)
        return best_llm
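

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original pipeline). The
# feature dimensions, mask layout, and random values below are assumptions
# chosen just to exercise the classes above: 4 queries, 3 candidate LLMs,
# 16-dim query features, 8-dim task/LLM features, and 3 edge channels
# (observed performance plus a 2-dim combined edge feature, i.e. in_edges=3).
if __name__ == "__main__":
    import numpy as np

    device = torch.device("cpu")
    num_query, num_llm = 4, 3
    query_feat = np.random.rand(num_query, 16)   # per-query embeddings (assumed dims)
    task_feat = np.random.rand(num_query, 8)     # per-query task embeddings
    llm_feat = np.random.rand(num_llm, 8)        # per-LLM embeddings

    # One edge per (query, LLM) pair: source = query node, target = LLM index
    # (formulation() shifts LLM indices so they follow the query nodes).
    org_node = [q for q in range(num_query) for _ in range(num_llm)]
    des_node = [l for _ in range(num_query) for l in range(num_llm)]
    num_edges = num_query * num_llm
    edge_feature = np.random.rand(num_edges)      # observed per-edge performance
    combined_edge = np.random.rand(num_edges, 2)  # extra per-edge features
    label = np.zeros(num_edges)

    # Edge-level masks: train on the first two queries, validate on the third,
    # predict on the last one.
    train_mask = torch.zeros(num_edges, dtype=torch.bool)
    valide_mask = torch.zeros(num_edges, dtype=torch.bool)
    test_mask = torch.zeros(num_edges, dtype=torch.bool)
    train_mask[:2 * num_llm] = True
    valide_mask[2 * num_llm:3 * num_llm] = True
    test_mask[3 * num_llm:] = True

    data = form_data(device).formulation(task_feat, query_feat, llm_feat,
                                         org_node, des_node, edge_feature, label,
                                         test_mask, combined_edge,
                                         train_mask, valide_mask, test_mask)

    model = EncoderDecoderNet(query_feature_dim=16, llm_feature_dim=8,
                              hidden_features=32, in_edges=3).to(device)
    model.eval()
    with torch.no_grad():
        scores = model(task_id=data.task_id, query_features=data.query_features,
                       llm_features=data.llm_features, edge_index=data.edge_index,
                       edge_mask=data.test_mask,
                       edge_can_see=torch.logical_or(data.train_mask, data.valide_mask),
                       edge_weight=data.combined_edge)
    # One predicted score per (test query, LLM) edge; GNN_prediction.test() would
    # additionally load a trained checkpoint and sample an LLM from these scores.
    print(scores.shape)  # torch.Size([3])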