/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.broadcast

import java.io._
import java.nio.ByteBuffer
import java.util.zip.Adler32

import scala.collection.JavaConverters._
import scala.reflect.ClassTag
import scala.util.Random

import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.io.CompressionCodec
import org.apache.spark.serializer.Serializer
import org.apache.spark.storage._
import org.apache.spark.util.Utils
import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}

/**
 * A BitTorrent-like implementation of [[org.apache.spark.broadcast.Broadcast]].
 *
 * The mechanism is as follows:
 *
 * The driver divides the serialized object into small chunks and
 * stores those chunks in the BlockManager of the driver.
 *
 * On each executor, the executor first attempts to fetch the object from its BlockManager. If
 * it does not exist, it then uses remote fetches to fetch the small chunks from the driver and/or
 * other executors if available. Once it gets the chunks, it puts the chunks in its own
 * BlockManager, ready for other executors to fetch from.
 *
 * This prevents the driver from being the bottleneck in sending out multiple copies of the
 * broadcast data (one per executor).
 *
 * When initialized, TorrentBroadcast objects read SparkEnv.get.conf.
 *
 * @param obj object to broadcast
 * @param id A unique identifier for the broadcast variable.
 */
private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
  extends Broadcast[T](id) with Logging with Serializable {

  /**
   * Value of the broadcast object on executors. This is reconstructed by [[readBroadcastBlock]],
   * which builds this value by reading blocks from the driver and/or other executors.
   *
   * On the driver, if the value is required, it is read lazily from the block manager.
   */
  @transient private lazy val _value: T = readBroadcastBlock()

  /** The compression codec to use, or None if compression is disabled */
  @transient private var compressionCodec: Option[CompressionCodec] = _
  /** Size of each block. Default value is 4MB. This value is only read by the broadcaster. */
  @transient private var blockSize: Int = _
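
  // Illustrative sketch, not part of the original file: the three settings that
  // setConf below reads, with the defaults it assumes. All of them can be set
  // explicitly on the application's SparkConf, e.g.:
  //
  //   val conf = new SparkConf()
  //     .set("spark.broadcast.compress", "true")   // compress blocks with the configured codec
  //     .set("spark.broadcast.blockSize", "4m")    // size of each broadcast piece
  //     .set("spark.broadcast.checksum", "true")   // Adler-32 checksum per piece
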
  private def setConf(conf: SparkConf) {
    compressionCodec = if (conf.getBoolean("spark.broadcast.compress", true)) {
      Some(CompressionCodec.createCodec(conf))
    } else {
      None
    }
    // Note: use getSizeAsKb (not bytes) to maintain compatibility if no units are provided
    blockSize = conf.getSizeAsKb("spark.broadcast.blockSize", "4m").toInt * 1024
    checksumEnabled = conf.getBoolean("spark.broadcast.checksum", true)
  }
  setConf(SparkEnv.get.conf)

  private val broadcastId = BroadcastBlockId(id)

  /** Total number of blocks this broadcast variable contains. */
  private val numBlocks: Int = writeBlocks(obj)

  /** Whether to generate checksum for blocks or not. */
  private var checksumEnabled: Boolean = false
  /** The checksum for all the blocks. */
  private var checksums: Array[Int] = _

  override protected def getValue() = {
    _value
  }

  private def calcChecksum(block: ByteBuffer): Int = {
    val adler = new Adler32()
    if (block.hasArray) {
      adler.update(block.array, block.arrayOffset + block.position(), block.limit()
        - block.position())
    } else {
      val bytes = new Array[Byte](block.remaining())
      block.duplicate.get(bytes)
      adler.update(bytes)
    }
    adler.getValue.toInt
  }

  /**
   * Divide the object into multiple blocks and put those blocks in the block manager.
   *
   * @param value the object to divide
   * @return number of blocks this broadcast variable is divided into
   */
  private def writeBlocks(value: T): Int = {
    import StorageLevel._
    // Store a copy of the broadcast variable in the driver so that tasks run on the driver
    // do not create a duplicate copy of the broadcast variable's value.
    val blockManager = SparkEnv.get.blockManager
    if (!blockManager.putSingle(broadcastId, value, MEMORY_AND_DISK, tellMaster = false)) {
      throw new SparkException(s"Failed to store $broadcastId in BlockManager")
    }
    val blocks =
      TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec)
    if (checksumEnabled) {
      checksums = new Array[Int](blocks.length)
    }
    blocks.zipWithIndex.foreach { case (block, i) =>
      if (checksumEnabled) {
        checksums(i) = calcChecksum(block)
      }
      val pieceId = BroadcastBlockId(id, "piece" + i)
      val bytes = new ChunkedByteBuffer(block.duplicate())
      if (!blockManager.putBytes(pieceId, bytes, MEMORY_AND_DISK_SER, tellMaster = true)) {
        throw new SparkException(s"Failed to store $pieceId of $broadcastId in local BlockManager")
      }
    }
    blocks.length
  }
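
  // Worked example, illustrative and not part of the original file: with the
  // default 4 MiB block size, an object whose serialized (and possibly
  // compressed) form is ~10 MiB is written by writeBlocks above as three
  // pieces, so writeBlocks returns 3 and readers can fetch the pieces
  // independently (and, in readBlocks below, in random order to spread load):
  //
  //   BroadcastBlockId(id, "piece0")  // 4 MiB
  //   BroadcastBlockId(id, "piece1")  // 4 MiB
  //   BroadcastBlockId(id, "piece2")  // ~2 MiB
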
  /** Fetch torrent blocks from the driver and/or other executors. */
  private def readBlocks(): Array[BlockData] = {
    // Fetch chunks of data. Note that all these chunks are stored in the BlockManager and reported
    // to the driver, so other executors can pull these chunks from this executor as well.
    val blocks = new Array[BlockData](numBlocks)
    val bm = SparkEnv.get.blockManager

    for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
      val pieceId = BroadcastBlockId(id, "piece" + pid)
      logDebug(s"Reading piece $pieceId of $broadcastId")
      // First try getLocalBytes because there is a chance that previous attempts to fetch the
      // broadcast blocks have already fetched some of the blocks. In that case, some blocks
      // would be available locally (on this executor).
      bm.getLocalBytes(pieceId) match {
        case Some(block) =>
          blocks(pid) = block
          releaseLock(pieceId)
        case None =>
          bm.getRemoteBytes(pieceId) match {
            case Some(b) =>
              if (checksumEnabled) {
                val sum = calcChecksum(b.chunks(0))
                if (sum != checksums(pid)) {
                  throw new SparkException(s"corrupt remote block $pieceId of $broadcastId:" +
                    s" $sum != ${checksums(pid)}")
                }
              }
              // We found the block from remote executors/driver's BlockManager, so put the block
              // in this executor's BlockManager.
              if (!bm.putBytes(pieceId, b, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)) {
                throw new SparkException(
                  s"Failed to store $pieceId of $broadcastId in local BlockManager")
              }
              blocks(pid) = new ByteBufferBlockData(b, true)
            case None =>
              throw new SparkException(s"Failed to get $pieceId of $broadcastId")
          }
      }
    }
    blocks
  }

  /**
   * Remove all persisted state associated with this Torrent broadcast on the executors.
   */
  override protected def doUnpersist(blocking: Boolean) {
    TorrentBroadcast.unpersist(id, removeFromDriver = false, blocking)
  }

  /**
   * Remove all persisted state associated with this Torrent broadcast on the executors
   * and driver.
   */
  override protected def doDestroy(blocking: Boolean) {
    TorrentBroadcast.unpersist(id, removeFromDriver = true, blocking)
  }

  /** Used by the JVM when serializing this object. */
  private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
    assertValid()
    out.defaultWriteObject()
  }

  private def readBroadcastBlock(): T = Utils.tryOrIOException {
    TorrentBroadcast.synchronized {
      setConf(SparkEnv.get.conf)
      val blockManager = SparkEnv.get.blockManager
      blockManager.getLocalValues(broadcastId) match {
        case Some(blockResult) =>
          if (blockResult.data.hasNext) {
            val x = blockResult.data.next().asInstanceOf[T]
            releaseLock(broadcastId)
            x
          } else {
            throw new SparkException(s"Failed to get locally stored broadcast data: $broadcastId")
          }
        case None =>
          logInfo("Started reading broadcast variable " + id)
          val startTimeMs = System.currentTimeMillis()
          val blocks = readBlocks()
          logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs))

          try {
            val obj = TorrentBroadcast.unBlockifyObject[T](
              blocks.map(_.toInputStream()), SparkEnv.get.serializer, compressionCodec)
            // Store the merged copy in BlockManager so other tasks on this executor don't
            // need to re-fetch it.
            val storageLevel = StorageLevel.MEMORY_AND_DISK
            if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) {
              throw new SparkException(s"Failed to store $broadcastId in BlockManager")
            }
            obj
          } finally {
            blocks.foreach(_.dispose())
          }
      }
    }
  }
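
  // Illustrative usage, not part of the original file; `sc` is a SparkContext
  // and `rdd` a hypothetical RDD[String]. On an executor, the first access to
  // `.value` forces the lazy _value and runs readBroadcastBlock() above:
  // local lookup first, then readBlocks() for any missing pieces, then a
  // putSingle so later tasks on this executor reuse the reassembled copy.
  //
  //   val words = sc.broadcast(Seq("spark", "broadcast"))   // on the driver
  //   rdd.map(line => words.value.exists(line.contains))    // fetch happens here
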
  /**
   * If running in a task, register the given block's locks for release upon task completion.
   * Otherwise, if not running in a task then immediately release the lock.
   */
  private def releaseLock(blockId: BlockId): Unit = {
    val blockManager = SparkEnv.get.blockManager
    Option(TaskContext.get()) match {
      case Some(taskContext) =>
        taskContext.addTaskCompletionListener(_ => blockManager.releaseLock(blockId))
      case None =>
        // This should only happen on the driver, where broadcast variables may be accessed
        // outside of running tasks (e.g. when computing rdd.partitions()). In order to allow
        // broadcast variables to be garbage collected we need to free the reference here,
        // which is slightly unsafe but is technically okay because broadcast variables aren't
        // stored off-heap.
        blockManager.releaseLock(blockId)
    }
  }

}


private object TorrentBroadcast extends Logging {

  def blockifyObject[T: ClassTag](
      obj: T,
      blockSize: Int,
      serializer: Serializer,
      compressionCodec: Option[CompressionCodec]): Array[ByteBuffer] = {
    val cbbos = new ChunkedByteBufferOutputStream(blockSize, ByteBuffer.allocate)
    val out = compressionCodec.map(c => c.compressedOutputStream(cbbos)).getOrElse(cbbos)
    val ser = serializer.newInstance()
    val serOut = ser.serializeStream(out)
    Utils.tryWithSafeFinally {
      serOut.writeObject[T](obj)
    } {
      serOut.close()
    }
    cbbos.toChunkedByteBuffer.getChunks()
  }

  def unBlockifyObject[T: ClassTag](
      blocks: Array[InputStream],
      serializer: Serializer,
      compressionCodec: Option[CompressionCodec]): T = {
    require(blocks.nonEmpty, "Cannot unblockify an empty array of blocks")
    val is = new SequenceInputStream(blocks.iterator.asJavaEnumeration)
    val in: InputStream = compressionCodec.map(c => c.compressedInputStream(is)).getOrElse(is)
    val ser = serializer.newInstance()
    val serIn = ser.deserializeStream(in)
    val obj = Utils.tryWithSafeFinally {
      serIn.readObject[T]()
    } {
      serIn.close()
    }
    obj
  }

  /**
   * Remove all persisted blocks associated with this torrent broadcast on the executors.
   * If removeFromDriver is true, also remove these persisted blocks on the driver.
   */
  def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit = {
    logDebug(s"Unpersisting TorrentBroadcast $id")
    SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking)
  }
}
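
// Round-trip sketch for the helpers above (illustrative, not part of the
// original file; assumes a live SparkEnv and `bufferToStream`, a hypothetical
// ByteBuffer -> InputStream adapter):
//
//   val ser = SparkEnv.get.serializer
//   val codec = Some(CompressionCodec.createCodec(SparkEnv.get.conf))
//   val pieces = TorrentBroadcast.blockifyObject(Seq(1, 2, 3), 4 * 1024 * 1024, ser, codec)
//   val back = TorrentBroadcast.unBlockifyObject[Seq[Int]](pieces.map(bufferToStream), ser, codec)
//   assert(back == Seq(1, 2, 3))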