public class Record {
private final String description;
private final Map<String, Double> features;
// constructor, getter, toString, equals and hashcode
}
public interface Distance {
double calculate(Map<String, Double> f1, Map<String, Double> f2);
}
public class EuclideanDistance implements Distance {
@Override
public double calculate(Map<String, Double> f1, Map<String, Double> f2) {
double sum = 0;
for (String key : f1.keySet()) {
Double v1 = f1.get(key);
Double v2 = f2.get(key);
if (v1 != null && v2 != null) {
sum += Math.pow(v1 - v2, 2);
}
}
return Math.sqrt(sum);
}
}
public class Centroid {
private final Map<String, Double> coordinates;
// constructors, getter, toString, equals and hashcode
}
public class KMeans {
private static final Random random = new Random();
public static Map<Centroid, List<Record>> fit(List<Record> records,
int k,
Distance distance,
int maxIterations) {
// omitted
}
}
private static List<Centroid> randomCentroids(List<Record> records, int k) {
List<Centroid> centroids = new ArrayList<>();
Map<String, Double> maxs = new HashMap<>();
Map<String, Double> mins = new HashMap<>();
for (Record record : records) {
record.getFeatures().forEach((key, value) -> {
// compares the value with the current max and choose the bigger value between them
maxs.compute(key, (k1, max) -> max == null || value > max ? value : max);
// compare the value with the current min and choose the smaller value between them
mins.compute(key, (k1, min) -> min == null || value < min ? value : min);
});
}
Set<String> attributes = records.stream()
.flatMap(e -> e.getFeatures().keySet().stream())
.collect(toSet());
for (int i = 0; i < k; i++) {
Map<String, Double> coordinates = new HashMap<>();
for (String attribute : attributes) {
double max = maxs.get(attribute);
double min = mins.get(attribute);
coordinates.put(attribute, random.nextDouble() * (max - min) + min);
}
centroids.add(new Centroid(coordinates));
}
return centroids;
}
private static Centroid nearestCentroid(Record record, List<Centroid> centroids, Distance distance) {
double minimumDistance = Double.MAX_VALUE;
Centroid nearest = null;
for (Centroid centroid : centroids) {
double currentDistance = distance.calculate(record.getFeatures(), centroid.getCoordinates());
if (currentDistance < minimumDistance) {
minimumDistance = currentDistance;
nearest = centroid;
}
}
return nearest;
}
private static void assignToCluster(Map<Centroid, List<Record>> clusters,
Record record,
Centroid centroid) {
clusters.compute(centroid, (key, list) -> {
if (list == null) {
list = new ArrayList<>();
}
list.add(record);
return list;
});
}
private static Centroid average(Centroid centroid, List<Record> records) {
if (records == null || records.isEmpty()) {
return centroid;
}
Map<String, Double> average = centroid.getCoordinates();
records.stream().flatMap(e -> e.getFeatures().keySet().stream())
.forEach(k -> average.put(k, 0.0));
for (Record record : records) {
record.getFeatures().forEach(
(k, v) -> average.compute(k, (k1, currentValue) -> v + currentValue)
);
}
average.forEach((k, v) -> average.put(k, v / records.size()));
return new Centroid(average);
}
private static List<Centroid> relocateCentroids(Map<Centroid, List<Record>> clusters) {
return clusters.entrySet().stream().map(e -> average(e.getKey(), e.getValue())).collect(toList());
}
public static Map<Centroid, List<Record>> fit(List<Record> records,
int k,
Distance distance,
int maxIterations) {
List<Centroid> centroids = randomCentroids(records, k);
Map<Centroid, List<Record>> clusters = new HashMap<>();
Map<Centroid, List<Record>> lastState = new HashMap<>();
// iterate for a pre-defined number of times
for (int i = 0; i < maxIterations; i++) {
boolean isLastIteration = i == maxIterations - 1;
// in each iteration we should find the nearest centroid for each record
for (Record record : records) {
Centroid centroid = nearestCentroid(record, centroids, distance);
assignToCluster(clusters, record, centroid);
}
// if the assignments do not change, then the algorithm terminates
boolean shouldTerminate = isLastIteration || clusters.equals(lastState);
lastState = clusters;
if (shouldTerminate) {
break;
}
// at the end of each iteration we should relocate the centroids
centroids = relocateCentroids(clusters);
clusters = new HashMap<>();
}
return lastState;
}
Comments
Post a Comment