# Source: cc / app.py — commit 49a4932 ("Create app.py", verified), uploaded by rahideer, 6.74 kB.
# (File-viewer page chrome — "raw", "history blame" — is not part of the program.)
import streamlit as st
import javalang
import torch
import torch.nn.functional as F
import re
from transformers import AutoTokenizer, AutoModel
import warnings
import pandas as pd
import zipfile
import os
# Set up page config: browser-tab title and icon, full-width layout.
# NOTE: keep this ahead of every other st.* call; Streamlit has traditionally
# rejected set_page_config when it is not the first Streamlit command
# (verify against the pinned Streamlit version).
st.set_page_config(
    page_title="Java Code Clone Detector (IJaDataset 2.1)",
    # Was the mojibake "πŸ”" (UTF-8 bytes of the magnifying-glass emoji
    # decoded with a single-byte codec).
    page_icon="🔍",
    layout="wide",
)
# Silence every warning category so library noise does not reach the app output.
warnings.filterwarnings("ignore")

# Constants
MODEL_NAME = "microsoft/codebert-base"  # checkpoint id passed to from_pretrained
MAX_LENGTH = 512  # token budget handed to the tokenizer (longer inputs are truncated)
# Prefer the GPU when torch reports one, otherwise run on the CPU.
_DEVICE_NAME = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE = torch.device(_DEVICE_NAME)
DATASET_PATH = "ijadataset2-1.zip"  # Update this path if needed
# Initialize models with caching
@st.cache_resource
def _load_models_cached():
    """Load the tokenizer and model for ``MODEL_NAME``; let failures propagate.

    The cache only stores a value when this function returns. Raising on
    failure (instead of returning a ``(None, None)`` sentinel from inside the
    cached function) keeps one transient error, e.g. a network hiccup during
    download, from being remembered for the lifetime of the server process.
    """
    tok = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
    return tok, model


def load_models():
    """Return ``(tokenizer, model)`` with the model moved to ``DEVICE``.

    Returns:
        The cached pair on success. If loading raises, the error is shown with
        ``st.error`` and ``(None, None)`` is returned; because the failure is
        not cached, the next script run retries the load.
    """
    try:
        return _load_models_cached()
    except Exception as e:
        # Top-level UI boundary: report instead of crashing the whole page.
        st.error(f"Failed to load models: {str(e)}")
        return None, None
def _read_source_file(path):
    """Return the text of *path*, substituting U+FFFD for bytes that are not valid UTF-8."""
    # errors="replace": a single stray legacy-encoded byte must not abort the
    # whole dataset load through the caller's catch-all handler.
    with open(path, "r", encoding="utf-8", errors="replace") as handle:
        return handle.read()


def _first_pair_under(type_path):
    """Return ``(code1, code2)`` from the first directory under *type_path* holding at least two files, else ``None``."""
    for root, dirs, files in os.walk(type_path):
        # Sorting in place fixes the order in which os.walk descends, and
        # sorting the file names fixes which two files form the pair, so the
        # demo examples are the same on every machine and every run.
        dirs.sort()
        if len(files) >= 2:
            first_name, second_name = sorted(files)[:2]
            return (
                _read_source_file(os.path.join(root, first_name)),
                _read_source_file(os.path.join(root, second_name)),
            )
    return None


@st.cache_resource
def load_dataset():
    """Load one example pair per clone type from the extracted dataset.

    Extracts ``DATASET_PATH`` into the working directory when the
    ``Diverse_100K_Dataset`` folder is absent, then, for each clone-type
    directory that exists, takes the first two files (by sorted name) of the
    first directory containing at least two files.

    Returns:
        A list (at most 10 entries) of dicts with keys ``"type"``, ``"code1"``
        and ``"code2"``. On any error the message is shown with ``st.error``
        and an empty list is returned.
    """
    try:
        # Extract dataset if needed
        if not os.path.exists("Diverse_100K_Dataset"):
            with zipfile.ZipFile(DATASET_PATH, "r") as zip_ref:
                zip_ref.extractall(".")

        # Load sample pairs (modify this based on your dataset structure)
        clone_pairs = []
        base_path = "Diverse_100K_Dataset/Subject_CloneTypes_Directories"
        for clone_type in ["Clone_Type1", "Clone_Type2", "Clone_Type3 - ST"]:
            type_path = os.path.join(base_path, clone_type)
            if not os.path.exists(type_path):
                continue
            pair = _first_pair_under(type_path)  # just one pair per type for demo
            if pair is None:
                continue
            code1, code2 = pair
            clone_pairs.append({
                "type": clone_type,
                "code1": code1,
                "code2": code2,
            })
        return clone_pairs[:10]  # Return first 10 pairs for demo
    except Exception as e:
        # Top-level UI boundary: a missing or corrupt archive degrades to "no examples".
        st.error(f"Error loading dataset: {str(e)}")
        return []
# Run both loaders at script start. Their results are cached by Streamlit, and
# each one reports failures with st.error and hands back a fallback value
# ((None, None) for the models, [] for the dataset) instead of raising.
tokenizer, code_model = load_models()
dataset_pairs = load_dataset()
# Normalization function
# One left-to-right scan over the source. String and character literals are
# matched first so that comment markers inside them ("http://...") are left
# alone; whichever comment form starts first wins, so "//" inside a block
# comment (or "/*" inside a line comment) cannot confuse the other rule.
_JAVA_COMMENT_OR_LITERAL = re.compile(
    r'"(?:\\.|[^"\\\n])*"'     # string literal  -> kept
    r"|'(?:\\.|[^'\\\n])*'"    # char literal    -> kept
    r"|//[^\n]*"               # line comment    -> removed
    r"|/\*.*?\*/",             # block comment   -> removed
    re.DOTALL,
)


def _drop_if_comment(match):
    """Return '' for a comment match and the matched text unchanged for a literal."""
    text = match.group(0)
    return "" if text.startswith("/") else text


def normalize_code(code):
    """Strip Java comments from *code* and collapse whitespace.

    Args:
        code: Java source text. A non-string value is returned unchanged.

    Returns:
        The source with ``//`` and ``/* ... */`` comments removed, every run
        of whitespace replaced by a single space, and leading/trailing
        whitespace stripped. Comment markers inside string or character
        literals are preserved.
    """
    if not isinstance(code, str):
        # Same outcome the previous catch-all handler produced for bad input.
        return code
    without_comments = _JAVA_COMMENT_OR_LITERAL.sub(_drop_if_comment, code)
    return re.sub(r'\s+', ' ', without_comments).strip()  # Normalize whitespace
# Embedding generation
def get_embedding(code):
    """Return a mean-pooled embedding of *code* from the loaded model.

    The snippet is normalized with ``normalize_code``, tokenized (truncated to
    ``MAX_LENGTH`` tokens) and run through ``code_model`` without gradients.

    Args:
        code: Java source text.

    Returns:
        A tensor obtained by averaging ``last_hidden_state`` over the token
        dimension (dim 1), counting only positions the attention mask marks as
        real tokens; or ``None`` after showing ``st.error`` if anything fails.
    """
    try:
        code = normalize_code(code)
        # No padding: a single sequence needs none, and padding to MAX_LENGTH
        # made the pooled vector mostly an average of pad-token states, which
        # pushed unrelated snippets toward the same embedding.
        inputs = tokenizer(
            code,
            return_tensors="pt",
            truncation=True,
            max_length=MAX_LENGTH,
        ).to(DEVICE)
        with torch.no_grad():
            outputs = code_model(**inputs)
        hidden = outputs.last_hidden_state
        # Masked mean, so the pooling stays correct even if padding is ever
        # reintroduced (e.g. for batched inputs).
        mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
        token_count = mask.sum(dim=1).clamp(min=1)  # guard against division by zero
        return (hidden * mask).sum(dim=1) / token_count  # Pooled embedding
    except Exception as e:
        st.error(f"Error processing code: {str(e)}")
        return None
# Comparison function
def compare_code(code1, code2):
    """Return the cosine similarity between the embeddings of two snippets.

    Args:
        code1: First Java snippet.
        code2: Second Java snippet.

    Returns:
        A float in ``[-1.0, 1.0]``, or ``None`` when either snippet is
        missing, empty or whitespace-only, or when an embedding could not be
        computed.
    """
    # Whitespace-only text carries no code; treat it like an empty input
    # rather than scoring a meaningless embedding.
    if not code1 or not code2 or not code1.strip() or not code2.strip():
        return None

    with st.spinner('Analyzing code...'):
        emb1 = get_embedding(code1)
        emb2 = get_embedding(code2)

    if emb1 is None or emb2 is None:
        return None

    similarity = F.cosine_similarity(emb1, emb2).item()
    # Floating-point rounding can land a hair outside the mathematical range
    # (e.g. 1.0000001 for identical inputs). Explicit comparisons are used so
    # that a NaN passes through unchanged instead of being turned into a bound.
    if similarity > 1.0:
        similarity = 1.0
    elif similarity < -1.0:
        similarity = -1.0
    return similarity
# UI Elements
# Title emoji was the mojibake "πŸ”" (mis-decoded UTF-8 magnifying glass).
st.title("🔍 Java Code Clone Detector (IJaDataset 2.1)")
st.markdown("""
Compare Java code snippets from the IJaDataset 2.1 using CodeBERT embeddings.
""")
# Dataset selector: offer the preloaded pairs (if any) in a dropdown, labelled
# "<1-based position>: <clone type>", and remember the chosen pair.
selected_pair = None
if dataset_pairs:
    pair_options = {}
    for position, candidate in enumerate(dataset_pairs, start=1):
        label = f"{position}: {candidate['type']}"
        pair_options[label] = candidate
    selected_option = st.selectbox("Select a preloaded example pair:", list(pair_options.keys()))
    selected_pair = pair_options[selected_option]
# Layout: two side-by-side editors, pre-filled from the selected example pair
# when there is one and empty otherwise.
default_first = selected_pair["code1"] if selected_pair else ""
default_second = selected_pair["code2"] if selected_pair else ""

col1, col2 = st.columns(2)
with col1:
    code1 = st.text_area(
        "First Java Code",
        height=300,
        value=default_first,
        help="Enter the first Java code snippet",
    )
with col2:
    code2 = st.text_area(
        "Second Java Code",
        height=300,
        value=default_second,
        help="Enter the second Java code snippet",
    )
# Threshold slider: similarity at or above this value is reported as a clone.
# Range 0.50-1.00 in steps of 0.01, starting at 0.85.
threshold = st.slider(
    "Clone Detection Threshold",
    help="Adjust the similarity threshold for clone detection",
    min_value=0.5,
    max_value=1.0,
    step=0.01,
    value=0.85,
)
# Compare button: run the comparison and render score, threshold and verdict.
if st.button("Compare Code", type="primary"):
    if tokenizer is None or code_model is None:
        st.error("Models failed to load. Please check the logs.")
    elif not code1.strip() or not code2.strip():
        # Previously an empty editor made the button do nothing, silently.
        st.warning("Please enter both code snippets to compare.")
    else:
        similarity = compare_code(code1, code2)
        if similarity is not None:
            # Display results
            st.subheader("Results")

            # Progress bar for visualization. st.progress rejects floats
            # outside [0.0, 1.0], while cosine similarity may be negative or
            # exceed 1.0 by a rounding error, so clamp for display only.
            st.progress(min(1.0, max(0.0, similarity)))

            # Metrics columns (named so they do not shadow the editor columns)
            score_col, threshold_col, verdict_col = st.columns(3)
            with score_col:
                st.metric("Similarity Score", f"{similarity:.3f}")
            with threshold_col:
                st.metric("Threshold", f"{threshold:.3f}")
            with verdict_col:
                is_clone = similarity >= threshold
                st.metric(
                    "Clone Detection",
                    # "✅" was the mojibake "βœ…" (mis-decoded UTF-8).
                    "✅ Clone" if is_clone else "❌ Not a Clone",
                    delta=f"{similarity-threshold:+.3f}",
                )

            # Show normalized code for debugging
            with st.expander("Show normalized code"):
                tab1, tab2 = st.tabs(["First Code", "Second Code"])
                with tab1:
                    st.code(normalize_code(code1))
                with tab2:
                    st.code(normalize_code(code2))
# Footer: horizontal rule followed by a short note about the dataset.
dataset_info = """
**Dataset Information**:
- Using IJaDataset 2.1 from Kaggle
- Contains 100K Java files with clone annotations
- Clone types: Type-1, Type-2, and Type-3 clones
"""
st.markdown("---")
st.markdown(dataset_info)